From 506d786604fa04ec3cc36694f4a02c97e9672213 Mon Sep 17 00:00:00 2001 From: Vivek Gopalakrishnan Date: Tue, 11 Nov 2025 11:26:51 -0500 Subject: [PATCH 1/5] Add roma for faster rotation geodesic (#402) * Add RoMa to requirements * Switch to RoMa backend for rotation geodesic * Fix environment.yml * Add roma to pip requirements --- diffdrr/metrics.py | 6 ++---- environment.yml | 2 ++ notebooks/api/05_metrics.ipynb | 4 ++-- settings.ini | 3 +-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/diffdrr/metrics.py b/diffdrr/metrics.py index a75265289..f1f0188aa 100644 --- a/diffdrr/metrics.py +++ b/diffdrr/metrics.py @@ -158,7 +158,7 @@ def forward( return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1) # %% ../notebooks/api/05_metrics.ipynb 18 -from .pose import so3_log_map +from roma import rotmat_geodesic_distance class DoubleGeodesicSE3(torch.nn.Module): @@ -175,9 +175,7 @@ def __init__( self.sdr = sdd / 2 self.eps = eps - self.rot_geo = lambda r1, r2: self.sdr * so3_log_map( - r1.transpose(-1, -2) @ r2 - ).norm(dim=-1) + self.rot_geo = lambda r1, r2: rotmat_geodesic_distance(r1, r2) self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1) def forward(self, pose_1: RigidTransform, pose_2: RigidTransform): diff --git a/environment.yml b/environment.yml index 89948b867..65fdecac3 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,5 @@ dependencies: - tqdm - pyvista>=0.45.0 - vtk>9.4.0 + - pip: + - roma diff --git a/notebooks/api/05_metrics.ipynb b/notebooks/api/05_metrics.ipynb index 74f993872..93dd5aec3 100644 --- a/notebooks/api/05_metrics.ipynb +++ b/notebooks/api/05_metrics.ipynb @@ -431,7 +431,7 @@ "outputs": [], "source": [ "#| export\n", - "from diffdrr.pose import so3_log_map\n", + "from roma import rotmat_geodesic_distance\n", "\n", "\n", "class DoubleGeodesicSE3(torch.nn.Module):\n", @@ -448,7 +448,7 @@ " self.sdr = sdd / 2\n", " self.eps = eps\n", "\n", - " self.rot_geo = lambda r1, r2: self.sdr * so3_log_map(r1.transpose(-1, -2) @ r2).norm(dim=-1)\n", + " self.rot_geo = lambda r1, r2: rotmat_geodesic_distance(r1, r2)\n", " self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1)\n", "\n", " def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):\n", diff --git a/settings.ini b/settings.ini index 66823fff2..90a706c62 100644 --- a/settings.ini +++ b/settings.ini @@ -26,7 +26,7 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = eigenvivek -requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia +requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia roma pip_requirements = torch conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort @@ -38,4 +38,3 @@ clean_ids = True clear_all = False cell_number = True skip_procs = - From 202e84836f4adb03655432d93bae31e1770ae9c9 Mon Sep 17 00:00:00 2001 From: eigenvivek Date: Tue, 11 Nov 2025 12:57:01 -0500 Subject: [PATCH 2/5] Revert "Add roma for faster rotation geodesic (#402)" This reverts commit 506d786604fa04ec3cc36694f4a02c97e9672213. --- diffdrr/metrics.py | 6 ++++-- environment.yml | 2 -- notebooks/api/05_metrics.ipynb | 4 ++-- settings.ini | 3 ++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/diffdrr/metrics.py b/diffdrr/metrics.py index f1f0188aa..a75265289 100644 --- a/diffdrr/metrics.py +++ b/diffdrr/metrics.py @@ -158,7 +158,7 @@ def forward( return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1) # %% ../notebooks/api/05_metrics.ipynb 18 -from roma import rotmat_geodesic_distance +from .pose import so3_log_map class DoubleGeodesicSE3(torch.nn.Module): @@ -175,7 +175,9 @@ def __init__( self.sdr = sdd / 2 self.eps = eps - self.rot_geo = lambda r1, r2: rotmat_geodesic_distance(r1, r2) + self.rot_geo = lambda r1, r2: self.sdr * so3_log_map( + r1.transpose(-1, -2) @ r2 + ).norm(dim=-1) self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1) def forward(self, pose_1: RigidTransform, pose_2: RigidTransform): diff --git a/environment.yml b/environment.yml index 65fdecac3..89948b867 100644 --- a/environment.yml +++ b/environment.yml @@ -20,5 +20,3 @@ dependencies: - tqdm - pyvista>=0.45.0 - vtk>9.4.0 - - pip: - - roma diff --git a/notebooks/api/05_metrics.ipynb b/notebooks/api/05_metrics.ipynb index 93dd5aec3..74f993872 100644 --- a/notebooks/api/05_metrics.ipynb +++ b/notebooks/api/05_metrics.ipynb @@ -431,7 +431,7 @@ "outputs": [], "source": [ "#| export\n", - "from roma import rotmat_geodesic_distance\n", + "from diffdrr.pose import so3_log_map\n", "\n", "\n", "class DoubleGeodesicSE3(torch.nn.Module):\n", @@ -448,7 +448,7 @@ " self.sdr = sdd / 2\n", " self.eps = eps\n", "\n", - " self.rot_geo = lambda r1, r2: rotmat_geodesic_distance(r1, r2)\n", + " self.rot_geo = lambda r1, r2: self.sdr * so3_log_map(r1.transpose(-1, -2) @ r2).norm(dim=-1)\n", " self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1)\n", "\n", " def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):\n", diff --git a/settings.ini b/settings.ini index 90a706c62..66823fff2 100644 --- a/settings.ini +++ b/settings.ini @@ -26,7 +26,7 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = eigenvivek -requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia roma +requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia pip_requirements = torch conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort @@ -38,3 +38,4 @@ clean_ids = True clear_all = False cell_number = True skip_procs = + From 5587c766f8fe20682921114f0233603cacf6730a Mon Sep 17 00:00:00 2001 From: Vivek Gopalakrishnan Date: Sat, 15 Nov 2025 21:15:10 -0500 Subject: [PATCH 3/5] Add orthogonality check to RigidTransform inverse (#403) * Add roma to dependencies * Add rotation matrix check in inverse * Make Rigid.Transform__getitem__ return a RigidTransform * Change check from rotation to orthonormal * Add roma to requirements * Roll back Rigid.Transform__getitem__ return a RigidTransform --- diffdrr/pose.py | 17 +++++++++++------ environment.yml | 2 ++ notebooks/api/06_pose.ipynb | 17 +++++++++++------ settings.ini | 4 ++-- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 482005c94..f80964176 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -8,6 +8,7 @@ import torch from einops import rearrange +from roma import is_orthonormal_matrix class RigidTransform(torch.nn.Module): @@ -17,11 +18,12 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ - def __init__(self, matrix): + def __init__(self, matrix, eps=1e-6): super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) self.register_buffer("matrix", matrix) + self.eps = eps def __len__(self): return len(self.matrix) @@ -43,11 +45,14 @@ def translation(self): return self.matrix[..., :3, 3] def inverse(self): - R = self.matrix[..., :3, :3] - t = self.matrix[..., :3, 3] - Rinv = R.mT - tinv = -torch.einsum("bij, bj -> bi", Rinv, t) - matrix = make_matrix(Rinv, tinv) + if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps): + R = self.matrix[..., :3, :3] + t = self.matrix[..., :3, 3] + Rinv = R.mT + tinv = -torch.einsum("bij, bj -> bi", Rinv, t) + matrix = make_matrix(Rinv, tinv) + else: + matrix = self.matrix.inverse() return RigidTransform(matrix) def compose(self, T): diff --git a/environment.yml b/environment.yml index 89948b867..65fdecac3 100644 --- a/environment.yml +++ b/environment.yml @@ -20,3 +20,5 @@ dependencies: - tqdm - pyvista>=0.45.0 - vtk>9.4.0 + - pip: + - roma diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 4db672906..b565ea2e8 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -92,6 +92,7 @@ "import torch\n", "\n", "from einops import rearrange\n", + "from roma import is_orthonormal_matrix\n", "\n", "\n", "class RigidTransform(torch.nn.Module):\n", @@ -101,11 +102,12 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", - " def __init__(self, matrix):\n", + " def __init__(self, matrix, eps=1e-6):\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", " self.register_buffer(\"matrix\", matrix)\n", + " self.eps = eps\n", "\n", " def __len__(self):\n", " return len(self.matrix)\n", @@ -127,11 +129,14 @@ " return self.matrix[..., :3, 3]\n", "\n", " def inverse(self):\n", - " R = self.matrix[..., :3, :3]\n", - " t = self.matrix[..., :3, 3]\n", - " Rinv = R.mT\n", - " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", - " matrix = make_matrix(Rinv, tinv)\n", + " if is_orthonormal_matrix(self.matrix[..., :3, :3], self.eps):\n", + " R = self.matrix[..., :3, :3]\n", + " t = self.matrix[..., :3, 3]\n", + " Rinv = R.mT\n", + " tinv = -torch.einsum(\"bij, bj -> bi\", Rinv, t)\n", + " matrix = make_matrix(Rinv, tinv)\n", + " else:\n", + " matrix = self.matrix.inverse()\n", " return RigidTransform(matrix)\n", "\n", " def compose(self, T):\n", diff --git a/settings.ini b/settings.ini index 66823fff2..89fe18984 100644 --- a/settings.ini +++ b/settings.ini @@ -26,8 +26,8 @@ keywords = nbdev jupyter notebook python language = English status = 3 user = eigenvivek -requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia -pip_requirements = torch +requirements = matplotlib seaborn tqdm imageio fastcore 'pyvista[all]' einops torchvision scipy torchio timm numpy kornia roma +pip_requirements = torch roma conda_requirements = pytorch dev_requirements = nbdev black flake8 ipykernel ipywidgets jupyterlab jupyterlab_execute_time jupyterlab-code-formatter isort readme_nb = index.ipynb From c4c72270dac27683a55f57e36ee5071f5b20efc6 Mon Sep 17 00:00:00 2001 From: Vivek Gopalakrishnan Date: Sat, 15 Nov 2025 23:22:36 -0500 Subject: [PATCH 4/5] Various utilities for RigidTransforms (#404) * If input is RigidTransform, return input * Remove hardcoded class name * Make slicing return a RigidTransform * Add a matmul overload --- diffdrr/_modidx.py | 2 ++ diffdrr/pose.py | 17 ++++++++++++++--- notebooks/api/06_pose.ipynb | 17 ++++++++++++++--- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/diffdrr/_modidx.py b/diffdrr/_modidx.py index db423cd6a..97b7b2238 100644 --- a/diffdrr/_modidx.py +++ b/diffdrr/_modidx.py @@ -87,6 +87,8 @@ 'diffdrr.pose.RigidTransform.__getitem__': ('api/pose.html#rigidtransform.__getitem__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__init__': ('api/pose.html#rigidtransform.__init__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.__len__': ('api/pose.html#rigidtransform.__len__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__matmul__': ('api/pose.html#rigidtransform.__matmul__', 'diffdrr/pose.py'), + 'diffdrr.pose.RigidTransform.__new__': ('api/pose.html#rigidtransform.__new__', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.compose': ('api/pose.html#rigidtransform.compose', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.convert': ('api/pose.html#rigidtransform.convert', 'diffdrr/pose.py'), 'diffdrr.pose.RigidTransform.forward': ('api/pose.html#rigidtransform.forward', 'diffdrr/pose.py'), diff --git a/diffdrr/pose.py b/diffdrr/pose.py index f80964176..92c9dcb2a 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -18,7 +18,15 @@ class RigidTransform(torch.nn.Module): inversion, and conversions to various representations of SE(3). """ + def __new__(cls, matrix, eps=1e-6): + if isinstance(matrix, cls): + return matrix + return super().__new__(cls) + def __init__(self, matrix, eps=1e-6): + if isinstance(matrix, type(self)): + return + super().__init__() if matrix.dim() == 2: matrix = matrix.unsqueeze(0) @@ -29,7 +37,10 @@ def __len__(self): return len(self.matrix) def __getitem__(self, idx): - return self.matrix[idx] + return type(self)(self.matrix[idx]) + + def __matmul__(self, T): + return T.compose(self) def forward(self, x): """Apply (a batch) of rigid transforms to a pointcloud.""" @@ -53,11 +64,11 @@ def inverse(self): matrix = make_matrix(Rinv, tinv) else: matrix = self.matrix.inverse() - return RigidTransform(matrix) + return type(self)(matrix) def compose(self, T): matrix = torch.einsum("bij, bjk -> bik", T.matrix, self.matrix) - return RigidTransform(matrix) + return type(self)(matrix) def convert(self, parameterization, convention=None, degrees=False): translation = -self.inverse().translation diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index b565ea2e8..8555a35e2 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -102,7 +102,15 @@ " inversion, and conversions to various representations of SE(3).\n", " \"\"\"\n", "\n", + " def __new__(cls, matrix, eps=1e-6):\n", + " if isinstance(matrix, cls):\n", + " return matrix\n", + " return super().__new__(cls)\n", + "\n", " def __init__(self, matrix, eps=1e-6):\n", + " if isinstance(matrix, type(self)):\n", + " return \n", + "\n", " super().__init__()\n", " if matrix.dim() == 2:\n", " matrix = matrix.unsqueeze(0)\n", @@ -113,7 +121,10 @@ " return len(self.matrix)\n", "\n", " def __getitem__(self, idx):\n", - " return self.matrix[idx]\n", + " return type(self)(self.matrix[idx])\n", + "\n", + " def __matmul__(self, T):\n", + " return T.compose(self)\n", "\n", " def forward(self, x):\n", " \"\"\"Apply (a batch) of rigid transforms to a pointcloud.\"\"\"\n", @@ -137,11 +148,11 @@ " matrix = make_matrix(Rinv, tinv)\n", " else:\n", " matrix = self.matrix.inverse()\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def compose(self, T):\n", " matrix = torch.einsum(\"bij, bjk -> bik\", T.matrix, self.matrix)\n", - " return RigidTransform(matrix)\n", + " return type(self)(matrix)\n", "\n", " def convert(self, parameterization, convention=None, degrees=False):\n", " translation = -self.inverse().translation\n", From 2e402a47e90dab3e68006b37246b3864dd0ba359 Mon Sep 17 00:00:00 2001 From: Henry Krumb Date: Thu, 18 Dec 2025 15:47:33 +0100 Subject: [PATCH 5/5] fix walrus operator in assert statement, leading to problems with optimized code --- diffdrr/pose.py | 4 ++-- notebooks/api/06_pose.ipynb | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/diffdrr/pose.py b/diffdrr/pose.py index 92c9dcb2a..f3d97e651 100644 --- a/diffdrr/pose.py +++ b/diffdrr/pose.py @@ -106,13 +106,13 @@ def get_se3_log(self): # %% ../notebooks/api/06_pose.ipynb 7 def make_matrix(R, t): - assert (batch_size := len(R)) == len(t) + batch_size = len(R) + assert batch_size == len(t) matrix = torch.zeros(batch_size, 4, 4).to(R) matrix[..., :3, :3] = R matrix[..., :3, 3] = t matrix[..., -1, -1] = 1.0 return matrix - # %% ../notebooks/api/06_pose.ipynb 8 from scipy.spatial.transform import Rotation diff --git a/notebooks/api/06_pose.ipynb b/notebooks/api/06_pose.ipynb index 8555a35e2..2aad6d4b4 100644 --- a/notebooks/api/06_pose.ipynb +++ b/notebooks/api/06_pose.ipynb @@ -198,7 +198,8 @@ "source": [ "#| exporti\n", "def make_matrix(R, t):\n", - " assert (batch_size := len(R)) == len(t)\n", + " batch_size = len(R)\n", + " assert batch_size == len(t)\n", " matrix = torch.zeros(batch_size, 4, 4).to(R)\n", " matrix[..., :3, :3] = R\n", " matrix[..., :3, 3] = t\n",