diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index db423cd6a..97b7b2238 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -87,6 +87,8 @@ 'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__matmul__': ('api/pose.html#rigidtransform.__matmul__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'), diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 482005c94..f3d97e651 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -8,6 +8,7 @@ import torch from einops import rearrange +from roma import is_orthonormal_matrix class RigidTransform(torch.nn.Module): @@ -17,17 +18,29 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ - def __init__(self, matrix): + def __new__(cls, matrix, eps=1e-6): + if isinstance(matrix, cls): + return matrix + return super().__new__(cls) + + def __init__(self, matrix, eps=1e-6): + if isinstance(matrix, type(self)): + return + super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) self.register_buffer("matrix", matrix) + self.eps = eps def __len__(self): return len(self.matrix) def __getitem__(self, idx): - return self.matrix[idx] + return type(self)(self.matrix[idx]) + + def __matmul__(self, T): + return T.compose(self) def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" @@ -43,16 +56,19 @@ def translation(self): return self.matrix[..., :3, 3] def inverse(self): - R = self.matrix[..., :3, :3] - t = self.matrix[..., :3, 3] - Rinv = R.mT - tinv = -torch.einsum("bij, bj -> bi", Rinv, t) - matrix = make_matrix(Rinv, tinv) - return RigidTransform(matrix) + if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps): + R = self.matrix[..., :3, :3] + t = self.matrix[..., :3, 3] + Rinv = R.mT + tinv = -torch.einsum("bij, bj -> bi", Rinv, t) + matrix = make_matrix(Rinv, tinv) + else: + matrix = self.matrix.inverse() + return type(self)(matrix) def compose(self, T): matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix) - return RigidTransform(matrix) + return type(self)(matrix) def convert(self, parameterization, convention=None, degrees=False): translation = -self.inverse().translation @@ -90,13 +106,13 @@ def get_se3_log(self): # %% ../notebooks/api/06_pose.ipynb 7 def make_matrix(R, t): - assert (batch_size := len(R)) == len(t) + batch_size = len(R) + assert batch_size == len(t) matrix = torch.zeros(batch_size, 4, 4).to(R) matrix[..., :3, :3] = R matrix[..., :3, 3] = t matrix[..., -1, -1] = 1.0 return matrix - # %% ../notebooks/api/06_pose.ipynb 8 from scipy.spatial.transform import Rotation diff --git a/environment.yml b/environment.yml index 89948b867..65fdecac3 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,5 @@ dependencies: - tqdm - pyvista>=0.45.0 - vtk>9.4.0 + - pip: + - roma diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 4db672906..2aad6d4b4 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -92,6 +92,7 @@ "import torch\n", "\n", "from einops import rearrange\n", + "from roma import is_orthonormal_matrix\n", "\n", "\n", "class RigidTransform(torch.nn.Module):\n", @@ -101,17 +102,29 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", - " def __init__(self, matrix):\n", + " def __new__(cls, matrix, eps=1e-6):\n", + " if isinstance(matrix, cls):\n", + " return matrix\n", + " return super().__new__(cls)\n", + "\n", + " def __init__(self, matrix, eps=1e-6):\n", + " if isinstance(matrix, type(self)):\n", + " return \n", + "\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", " self.register_buffer(\"matrix\", matrix)\n", + " self.eps = eps\n", "\n", " def __len__(self):\n", " return len(self.matrix)\n", "\n", " def __getitem__(self, idx):\n", - " return self.matrix[idx]\n", + " return type(self)(self.matrix[idx])\n", + "\n", + " def __matmul__(self, T):\n", + " return T.compose(self)\n", "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", @@ -127,16 +140,19 @@ " return self.matrix[..., :3, 3]\n", "\n", " def inverse(self):\n", - " R = self.matrix[..., :3, :3]\n", - " t = self.matrix[..., :3, 3]\n", - " Rinv = R.mT\n", - " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", - " matrix = make_matrix(Rinv, tinv)\n", - " return RigidTransform(matrix)\n", + " if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps):\n", + " R = self.matrix[..., :3, :3]\n", + " t = self.matrix[..., :3, 3]\n", + " Rinv = R.mT\n", + " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", + " matrix = make_matrix(Rinv, tinv)\n", + " else:\n", + " matrix = self.matrix.inverse()\n", + " return type(self)(matrix)\n", "\n", " def compose(self, T):\n", " matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def convert(self, parameterization, convention=None, degrees=False):\n", " translation = -self.inverse().translation\n", @@ -182,7 +198,8 @@ "source": [ "#| exporti\n", "def make_matrix(R, t):\n", - " assert (batch_size := len(R)) == len(t)\n", + " batch_size = len(R)\n", + " assert batch_size == len(t)\n", " matrix = torch.zeros(batch_size, 4, 4).to(R)\n", " matrix[..., :3, :3] = R\n", " matrix[..., :3, 3] = t\n", diff --git a/settings.ini b/settings.ini index 66823fff2..89fe18984 100644 --- a/settings.ini +++ b/settings.ini @@ -26,8 +26,8 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = eigenvivek -requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia -pip_requirements = torch +requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia roma +pip_requirements = torch roma conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort readme_nb = index.ipynb