From 898274a6c09bd15c5fd1234c7f98eb5caf3c927e Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 26 Jun 2026 21:48:25 +0200 Subject: [PATCH 1/3] Add Wigner-D matrix infrastructure Real Wigner-D matrices for O(3), used to rotate spherical (o3_mu) tensors. `compute_wigner_batch` builds a batch of real D matrices up to a given angular momentum from ZYZ Euler angles via the standard recursion. A small TorchScript-compatibility `jit` helper (`_jit_compat`) lets the recursion be scripted. --- .../metatomic/torch/_augmentation/_wigner.py | 527 ++++++++++++++++++ .../metatomic/torch/_jit_compat.py | 13 + 2 files changed, 540 insertions(+) create mode 100644 python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py create mode 100644 python/metatomic_torch/metatomic/torch/_jit_compat.py diff --git a/python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py b/python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py new file mode 100644 index 000000000..213d0644f --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py @@ -0,0 +1,527 @@ +"""Private Wigner-d/Wigner-D helpers for symmetry operations. + +Adapted from the `spherical` project (MIT license), primarily from +`spherical/recursions/wignerH.py`, `spherical/utilities/indexing.py`, and the Wigner-D +assembly logic in `spherical/wigner.py`. + +This reduced metatomic copy keeps only the recurrence-based pieces needed to build real +Wigner-D matrices from ZYZ Euler angles. It intentionally does not depend on +`spinsfast`, `quaternionic`, or the public `spherical` package. +""" + +import functools + +import numpy as np +import torch + +from .._jit_compat import jit + + +@jit +def _epsilon(m: int) -> int: + if m <= 0: + return 1 + if m % 2: + return -1 + return 1 + + +@jit +def _nm_index(n: int, m: int) -> int: + return m + n * (n + 1) + + +@jit +def _nabsm_index(n: int, absm: int) -> int: + return absm + (n * (n + 1)) // 2 + + +@jit +def _wigner_h_size(mp_max: int, ell_max: int) -> int: + if ell_max < 0: + return 0 + if mp_max >= ell_max: + return (ell_max + 1) * (ell_max + 2) * (2 * ell_max + 3) // 6 + + return ( + (ell_max + 1) * (ell_max + 2) * (2 * ell_max + 3) + - 2 * (ell_max - mp_max) * (ell_max - mp_max + 1) * (ell_max - mp_max + 2) + ) // 6 + + +@jit +def _wigner_d_size(ell_min: int, mp_max: int, ell_max: int) -> int: + if mp_max >= ell_max: + return ( + ell_max * (ell_max * (4 * ell_max + 12) + 11) + + ell_min * (1 - 4 * ell_min**2) + + 3 + ) // 3 + if mp_max > ell_min: + return ( + 3 * ell_max * (ell_max + 2) + + ell_min * (1 - 4 * ell_min**2) + + mp_max + * (3 * ell_max * (2 * ell_max + 4) + mp_max * (-2 * mp_max - 3) + 5) + + 3 + ) // 3 + + return (ell_max * (ell_max + 2) - ell_min**2) * (1 + 2 * mp_max) + 2 * mp_max + 1 + + +@jit +def _wigner_h_index_base(ell: int, mp: int, m: int, mp_max: int) -> int: + local_mp_max = mp_max + if local_mp_max > ell: + local_mp_max = ell + idx = _wigner_h_size(local_mp_max, ell - 1) + if mp < 1: + idx += (local_mp_max + mp) * (2 * ell - local_mp_max + mp + 1) // 2 + else: + idx += (local_mp_max + 1) * (2 * ell - local_mp_max + 2) // 2 + idx += (mp - 1) * (2 * ell - mp + 2) // 2 + idx += m - abs(mp) + return idx + + +@jit +def _wigner_h_index(ell: int, mp: int, m: int, mp_max: int) -> int: + if ell == 0: + return 0 + + local_mp_max = mp_max + if local_mp_max > ell: + local_mp_max = ell + + if m < -mp: + if m < mp: + return _wigner_h_index_base(ell, -mp, -m, local_mp_max) + return _wigner_h_index_base(ell, -m, -mp, local_mp_max) + + if m < mp: + return _wigner_h_index_base(ell, m, mp, local_mp_max) + return _wigner_h_index_base(ell, mp, m, local_mp_max) + + +@jit +def _wigner_d_index(ell: int, mp: int, m: int, ell_min: int, mp_max: int) -> int: + idx = 0 + for ell_prev in range(ell_min, ell): + local_mp_max = mp_max if mp_max < ell_prev else ell_prev + idx += (2 * local_mp_max + 1) * (2 * ell_prev + 1) + + local_mp_max = mp_max if mp_max < ell else ell + idx += (mp + local_mp_max) * (2 * ell + 1) + idx += m + ell + return idx + + +@jit +def _step_1(hwedge): + """Seed the recurrence: D^0_{0,0} = 1.""" + hwedge[0] = 1.0 + + +@jit +def _step_2(g, h, n_max, mp_max, hwedge, hextra, hv, expi_beta): + """Fill the mp=0 column (top diagonal) of H-wedge using the g/h recurrence.""" + cos_beta = expi_beta.real + sin_beta = expi_beta.imag + sqrt3 = np.sqrt(3.0) + inverse_sqrt2 = 1.0 / np.sqrt(2.0) + if n_max > 0: + n0n_index = _wigner_h_index(1, 0, 1, mp_max) + nn_index = _nm_index(1, 1) + hwedge[n0n_index] = sqrt3 + hwedge[n0n_index - 1] = (g[nn_index - 1] * cos_beta) * inverse_sqrt2 + for n in range(2, n_max + 2): + if n <= n_max: + n0n_index = _wigner_h_index(n, 0, n, mp_max) + out = hwedge + else: + n0n_index = n + out = hextra + prev_index = _wigner_h_index(n - 1, 0, n - 1, mp_max) + nn_index = _nm_index(n, n) + const = np.sqrt(1.0 + 0.5 / n) + g_i = g[nn_index - 1] + out[n0n_index] = const * hwedge[prev_index] + out[n0n_index - 1] = g_i * cos_beta * out[n0n_index] + for i in range(2, n): + g_i = g[nn_index - i] + h_i = h[nn_index - i] + out[n0n_index - i] = ( + g_i * cos_beta * out[n0n_index - i + 1] + - h_i * sin_beta**2 * out[n0n_index - i + 2] + ) + const = 1.0 / np.sqrt(4 * n + 2) + g_i = g[nn_index - n] + h_i = h[nn_index - n] + out[n0n_index - n] = ( + g_i * cos_beta * out[n0n_index - n + 1] + - h_i * sin_beta**2 * out[n0n_index - n + 2] + ) * const + prefactor = const + for i in range(1, n): + prefactor *= sin_beta + out[n0n_index - n + i] *= prefactor + if n <= n_max: + hv[_nm_index(n, 1)] = hwedge[_wigner_h_index(n, 0, 1, mp_max)] + hv[_nm_index(n, 0)] = hwedge[_wigner_h_index(n, 0, 1, mp_max)] + prefactor = 1.0 + for n in range(1, n_max + 1): + prefactor *= sin_beta + hwedge[_wigner_h_index(n, 0, n, mp_max)] *= prefactor / np.sqrt(4 * n + 2) + prefactor *= sin_beta + hextra[n_max + 1] *= prefactor / np.sqrt(4 * (n_max + 1) + 2) + hv[_nm_index(1, 1)] = hwedge[_wigner_h_index(1, 0, 1, mp_max)] + hv[_nm_index(1, 0)] = hwedge[_wigner_h_index(1, 0, 1, mp_max)] + + +@jit +def _step_3(a, b, n_max, mp_max, hwedge, hextra, expi_beta): + """Fill the mp=1 sub-diagonal of H-wedge using the a/b recurrence.""" + cos_beta = expi_beta.real + sin_beta = expi_beta.imag + if n_max > 0 and mp_max > 0: + for n in range(1, n_max + 1): + i1 = _wigner_h_index(n, 1, 1, mp_max) + if n + 1 <= n_max: + i2 = _wigner_h_index(n + 1, 0, 0, mp_max) + h2 = hwedge + else: + i2 = 0 + h2 = hextra + i3 = _nm_index(n + 1, 0) + i4 = _nabsm_index(n, 1) + inverse_b5 = 1.0 / b[i3] + for i in range(n): + b6 = b[-i + i3 - 2] + b7 = b[i + i3] + a8 = a[i + i4] + hwedge[i + i1] = inverse_b5 * ( + 0.5 + * ( + b6 * (1 - cos_beta) * h2[i + i2 + 2] + - b7 * (1 + cos_beta) * h2[i + i2] + ) + - a8 * sin_beta * h2[i + i2 + 1] + ) + + +@jit +def _step_4(d, n_max, mp_max, hwedge, hv): + """Fill H-wedge for mp >= 2 by stepping up in mp via the d recurrence.""" + if n_max > 0 and mp_max > 0: + for n in range(2, n_max + 1): + for mp in range(1, min(n, mp_max)): + i1 = _wigner_h_index(n, mp + 1, mp + 1, mp_max) - 1 + i2 = _wigner_h_index(n, mp - 1, mp, mp_max) + i3 = _wigner_h_index(n, mp, mp, mp_max) - 1 + i4 = _wigner_h_index(n, mp, mp + 1, mp_max) + i5 = _nm_index(n, mp) + i6 = _nm_index(n, mp - 1) + inverse_d5 = 1.0 / d[i5] + d6 = d[i6] + hv[_nm_index(n, mp + 1)] = inverse_d5 * ( + d6 * hwedge[i2] - d[i6] * hv[_nm_index(n, mp)] + d[i5] * hwedge[i4] + ) + for i in range(1, n - mp): + d7 = d[i + i6] + d8 = d[i + i5] + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] - d7 * hwedge[i + i3] + d8 * hwedge[i + i4] + ) + i = n - mp + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] - d[i + i6] * hwedge[i + i3] + ) + + +@jit +def _step_5(d, n_max, mp_max, hwedge, hv): + """Fill H-wedge for mp <= 0 by stepping down in mp via the d recurrence.""" + if n_max > 0 and mp_max > 0: + for n in range(0, n_max + 1): + for mp in range(0, -min(n, mp_max), -1): + i1 = _wigner_h_index(n, mp - 1, -mp + 1, mp_max) - 1 + i2 = _wigner_h_index(n, mp + 1, -mp + 1, mp_max) - 1 + i3 = _wigner_h_index(n, mp, -mp, mp_max) - 1 + i4 = _wigner_h_index(n, mp, -mp + 1, mp_max) + i5 = _nm_index(n, mp - 1) + i6 = _nm_index(n, mp) + i7 = _nm_index(n, -mp - 1) + i8 = _nm_index(n, -mp) + inverse_d5 = 1.0 / d[i5] + d6 = d[i6] + d7 = d[i7] + d8 = d[i8] + if mp == 0: + hv[_nm_index(n, mp - 1)] = inverse_d5 * ( + d6 * hv[_nm_index(n, mp + 1)] + + d7 * hv[_nm_index(n, mp)] + - d8 * hwedge[i4] + ) + else: + hv[_nm_index(n, mp - 1)] = inverse_d5 * ( + d6 * hwedge[i2] + d7 * hv[_nm_index(n, mp)] - d8 * hwedge[i4] + ) + for i in range(1, n + mp): + d7 = d[i + i7] + d8 = d[i + i8] + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] + d7 * hwedge[i + i3] - d8 * hwedge[i + i4] + ) + i = n + mp + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] + d[i + i7] * hwedge[i + i3] + ) + + +def _create_wigner_coefficients(ell_max: int): + """Pre-compute the scalar recurrence coefficients used by the five-step Risbo + recurrence. + + Returns five arrays ``(a, b, d, g, h)`` indexed by ``(n, m)`` pairs flattened in + lexicographic order up to ``ell_max + 1`` (one step beyond the target to seed the + recurrence). The arrays contain the coupling constants that appear in the three-term + recurrences for the Wigner d-matrix entries. + + :param ell_max: highest angular-momentum order needed + :return: ``(a, b, d, g, h)`` numpy float arrays + """ + n = np.array([n for n in range(ell_max + 2) for _ in range(-n, n + 1)]) + m = np.array([m for n in range(ell_max + 2) for m in range(-n, n + 1)]) + absn = np.array([n for n in range(ell_max + 2) for _ in range(n + 1)]) + absm = np.array([m for n in range(ell_max + 2) for m in range(n + 1)]) + + a = np.sqrt( + (absn + 1 + absm) * (absn + 1 - absm) / ((2 * absn + 1) * (2 * absn + 3)) + ) + b = np.sqrt((n - m - 1) * (n - m) / ((2 * n - 1) * (2 * n + 1))) + b[m < 0] *= -1 + d = 0.5 * np.sqrt((n - m) * (n + m + 1)) + d[m < 0] *= -1 + with np.errstate(divide="ignore", invalid="ignore"): + g = 2 * (m + 1) / np.sqrt((n - m) * (n + m + 1)) + h = np.sqrt((n + m + 2) * (n - m - 1) / ((n - m) * (n + m + 1))) + return a, b, d, g, h + + +def _complex_powers(z: complex, ell_max: int) -> np.ndarray: + powers = np.empty(ell_max + 1, dtype=np.complex128) + powers[0] = 1.0 + 0.0j + for idx in range(1, ell_max + 1): + powers[idx] = powers[idx - 1] * z + return powers + + +def _to_euler_phases( + alpha: float, beta: float, gamma: float +) -> tuple[complex, complex, complex]: + # Match spherical.Wigner's convention after converting scipy's ZYZ Euler angles + # into the phases used by the recurrence. + z_alpha = np.exp(-1j * alpha) + expi_beta = np.exp(1j * beta) + z_gamma = np.exp(-1j * gamma) + return z_alpha, expi_beta, z_gamma + + +def _compute_wigner_d_complex( + ell_max: int, alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> np.ndarray: + """Compute complex Wigner-D matrix elements for all ell in ``[0, ell_max]``. + + Implements the five-step Risbo/Trapani-Navaza recurrence (see module docstring). + Results are stored in a flat array indexed by :func:`_wigner_d_index`. + + :param ell_max: maximum angular-momentum order + :param alpha: ZYZ first rotation angle, arbitrary shape + :param beta: ZYZ second rotation angle, same shape as ``alpha`` + :param gamma: ZYZ third rotation angle, same shape as ``alpha`` + :return: complex array of shape ``(*alpha.shape, dsize)`` + """ + if not (alpha.shape == beta.shape == gamma.shape): + raise ValueError("alpha, beta, and gamma must have identical shapes") + + mp_max = ell_max + a, b, d, g, h = _create_wigner_coefficients(ell_max) + hsize = _wigner_h_size(mp_max, ell_max) + dsize = _wigner_d_size(0, mp_max, ell_max) + result = np.zeros(alpha.shape + (dsize,), dtype=np.complex128) + + for index in np.ndindex(alpha.shape): + z_alpha, expi_beta, z_gamma = _to_euler_phases( + float(alpha[index]), float(beta[index]), float(gamma[index]) + ) + hwedge = np.zeros(hsize, dtype=np.float64) + hv = np.zeros((ell_max + 1) ** 2, dtype=np.float64) + hextra = np.zeros(ell_max + 2, dtype=np.float64) + + _step_1(hwedge) + _step_2(g, h, ell_max, mp_max, hwedge, hextra, hv, expi_beta) + _step_3(a, b, ell_max, mp_max, hwedge, hextra, expi_beta) + _step_4(d, ell_max, mp_max, hwedge, hv) + _step_5(d, ell_max, mp_max, hwedge, hv) + + z_alpha_powers = _complex_powers(z_alpha, ell_max) + z_gamma_powers = _complex_powers(z_gamma, ell_max) + out = result[index] + for ell in range(0, ell_max + 1): + for mp in range(-ell, 0): + i_d = _wigner_d_index(ell, mp, -ell, 0, mp_max) + for m in range(-ell, 0): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[-m].conjugate() + * z_alpha_powers[-mp].conjugate() + ) + i_d += 1 + for m in range(0, ell + 1): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[m] + * z_alpha_powers[-mp].conjugate() + ) + i_d += 1 + for mp in range(0, ell + 1): + i_d = _wigner_d_index(ell, mp, -ell, 0, mp_max) + for m in range(-ell, 0): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[-m].conjugate() + * z_alpha_powers[mp] + ) + i_d += 1 + for m in range(0, ell + 1): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[m] + * z_alpha_powers[mp] + ) + i_d += 1 + + return result + + +def compute_complex_wigner_d_matrices( + ell_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], +) -> dict[int, np.ndarray]: + """Return complex Wigner-D matrices for ell in ``[0, ell_max]`` at the given ZYZ + angles. + + :param ell_max: maximum angular-momentum order + :param angles: ``(alpha, beta, gamma)`` ZYZ Euler-angle arrays of matching shape + :return: ``{ell: array of shape (*angles[0].shape, 2*ell+1, 2*ell+1)}`` + """ + alpha, beta, gamma = angles + raw = _compute_wigner_d_complex(ell_max, alpha, beta, gamma) + matrices: dict[int, np.ndarray] = {} + for ell in range(ell_max + 1): + shape = alpha.shape + (2 * ell + 1, 2 * ell + 1) + block = np.zeros(shape, dtype=np.complex128) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + block[..., mp + ell, m + ell] = raw[ + ..., _wigner_d_index(ell, mp, m, 0, ell_max) + ] + matrices[ell] = block + return matrices + + +def compute_real_wigner_d_matrices( + ell_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], + complex_to_real: dict[int, np.ndarray], +) -> dict[int, torch.Tensor]: + """Convert complex Wigner-D matrices to real ones using the provided + change-of-basis. + + :param ell_max: maximum angular-momentum order + :param angles: ``(alpha, beta, gamma)`` ZYZ Euler-angle arrays + :param complex_to_real: ``{ell: (2*ell+1, 2*ell+1)}`` unitary transform from complex + to real spherical harmonics + :return: ``{ell: real tensor of shape (*angles[0].shape, 2*ell+1, 2*ell+1)}`` + :raises ValueError: if imaginary residuals exceed numerical tolerance + """ + complex_matrices = compute_complex_wigner_d_matrices(ell_max, angles) + real_matrices: dict[int, torch.Tensor] = {} + for ell, matrix in complex_matrices.items(): + transform = complex_to_real[ell] + matrix = np.einsum("ij,...jk,kl->...il", transform.conj(), matrix, transform.T) + # Recursion accumulates floating-point noise that grows with ell, so scale + # the tolerance by the magnitude of the matrix instead of using a fixed atol. + scale = float(np.max(np.abs(matrix.real))) if matrix.size else 1.0 + atol = max(1e-9, scale * 1e-10) + if not np.allclose(matrix.imag, 0.0, atol=atol): + raise ValueError("real Wigner matrix conversion produced complex values") + real_matrices[ell] = torch.from_numpy(matrix.real) + return real_matrices + + +@functools.lru_cache(maxsize=None) +def _complex_to_real_spherical_harmonics_transform(ell: int) -> np.ndarray: + """ + Generate the transformation matrix from complex spherical harmonics to real + spherical harmonics for a given l. Returns a transformation matrix of shape ``(2l+1, + 2l+1)``. + """ + if ell < 0 or not isinstance(ell, int): + raise ValueError("l must be a non-negative integer.") + + size = 2 * ell + 1 + T = np.zeros((size, size), dtype=complex) + + for m in range(-ell, ell + 1): + m_index = m + ell + if m > 0: + T[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m + T[m_index, ell - m] = 1 / np.sqrt(2) + elif m < 0: + T[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m + T[m_index, ell - abs(m)] = 1j / np.sqrt(2) + else: + T[m_index, ell] = 1 + + return T + + +def compute_real_wigner_matrices( + o3_lambda_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], +) -> dict[int, torch.Tensor]: + """Build the real Wigner-D matrices for ``ell = 0..o3_lambda_max`` at the given + ZYZ Euler angles, using the cached complex-to-real transform per ell.""" + complex_to_real = { + ell: _complex_to_real_spherical_harmonics_transform(ell) + for ell in range(o3_lambda_max + 1) + } + return compute_real_wigner_d_matrices(o3_lambda_max, angles, complex_to_real) + + +def compute_wigner_batch( + ell_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], + *, + device: torch.device, + dtype: torch.dtype, +) -> dict[int, torch.Tensor]: + """Real Wigner-D matrices for ``ell = 0..ell_max`` at the given angles, cast to + the requested device and dtype.""" + return { + ell: tensor.to(device=device, dtype=dtype) + for ell, tensor in compute_real_wigner_matrices(ell_max, angles).items() + } diff --git a/python/metatomic_torch/metatomic/torch/_jit_compat.py b/python/metatomic_torch/metatomic/torch/_jit_compat.py new file mode 100644 index 000000000..d6466203f --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_jit_compat.py @@ -0,0 +1,13 @@ +import functools + + +def _identity_decorator(func): + return func + + +try: + import numba as _numba +except ImportError: # pragma: no cover + jit = _identity_decorator +else: + jit = functools.partial(_numba.njit, cache=True) From 621c3a269f20ba1fdc449fe0cc03eb7c01578be6 Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 26 Jun 2026 21:48:48 +0200 Subject: [PATCH 2/3] Add O(3) augmentation for Systems and TensorMaps `apply_transformations` applies a batch of per-system O(3) matrices (proper or improper rotations) to a list of Systems and their target/extra-data TensorMaps simultaneously, deriving the needed Wigner-D matrices from the matrices themselves; `random_rotations` samples a uniform O(3) batch. Each component axis is transformed by tensor type inferred from its name: Cartesian axes (xyz/xyz_1/xyz_2) are contracted with R directly (so improper rotations flip them), and spherical axes (o3_mu/_1/_2) with the Wigner-D matrix of their o3_lambda, plus a (-1)^l * sigma parity per spherical axis when R is improper. A single TensorMap may mix scalar, Cartesian and spherical blocks, and gradients are transformed as blocks with extra axes. Value rows are routed to systems by their "system" label (remapping arbitrary dataset indices onto the provided systems); gradients are routed by the parent block's "system" label via the gradient "sample" column. System geometry, registered per-atom data and neighbor-list vectors are rotated too. The public entry point validates that the transformations are 3x3 and orthogonal and that transformations, systems and TensorMaps share a dtype and device. --- .../metatomic/torch/__init__.py | 1 + .../metatomic/torch/_augmentation/__init__.py | 667 ++++++++++++++++++ 2 files changed, 668 insertions(+) create mode 100644 python/metatomic_torch/metatomic/torch/_augmentation/__init__.py diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index a8bf363aa..dbbe3eb6d 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -52,6 +52,7 @@ pick_output = torch.ops.metatomic.pick_output from . import ase_calculator # noqa: F401 +from ._augmentation import apply_transformations, random_rotations # noqa: F401 from .model import ( # noqa: F401 AtomisticModel, ModelInterface, diff --git a/python/metatomic_torch/metatomic/torch/_augmentation/__init__.py b/python/metatomic_torch/metatomic/torch/_augmentation/__init__.py new file mode 100644 index 000000000..0c347b95c --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_augmentation/__init__.py @@ -0,0 +1,667 @@ +""" +O(3) augmentation: apply batched rotation/inversion transformations to Systems and +TensorMaps. + +metatomic stores positions and cell as row vectors (shape ``(N, 3)``), so a rotation +matrix ``R`` (3x3) acts as ``x @ R.T`` throughout this module. +""" + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import TensorBlock, TensorMap + +from .. import System, register_autograd_neighbors +from ._wigner import compute_wigner_batch + + +# Component-axis names recognised by the augmentation machinery. Cartesian axes are +# rotated by ``R`` directly (so improper rotations flip vectors); spherical axes are +# rotated by the Wigner-D matrix of the matching ``o3_lambda`` plus an explicit +# ``(-1)^ell * sigma`` inversion parity factor. +_CARTESIAN_AXES = frozenset({"xyz", "xyz_1", "xyz_2"}) +_SPHERICAL_AXIS_TO_LAMBDA = { + "o3_mu": "o3_lambda", + "o3_mu_1": "o3_lambda_1", + "o3_mu_2": "o3_lambda_2", +} +_SPHERICAL_AXIS_TO_SIGMA = { + "o3_mu": "o3_sigma", + "o3_mu_1": "o3_sigma_1", + "o3_mu_2": "o3_sigma_2", +} + +# einsum index letters for the per-axis component contraction (input lower, output +# upper). Six axes is far more than any realistic block (value + gradient) needs. +_EINSUM_IN = "abcdef" +_EINSUM_OUT = "ABCDEF" + + +def _row_indices_from_system_ids( + system_ids: torch.Tensor, + n_systems: int, +) -> list[torch.Tensor]: + """Group row indices by system, mapping each distinct ``system`` id to one system. + + The ``"system"`` column is not necessarily a 0-based batch index: metatrain keeps + each structure's original dataset index there, so the labels in a batch are an + arbitrary set (e.g. ``[67, 92, 38, ...]``). The i-th distinct id, **in sorted + order**, is taken to be system ``i``; when the ids already span ``[0, n_systems)`` + this reduces to the identity grouping. + + .. note:: + + This pairs the i-th sorted label with ``systems[i]``/``transformations[i]``, + which is only consistent with the (positional) system rotation when the caller + passes ``systems`` ordered by their label. metatrain satisfies this. + + :param system_ids: 1-D tensor with one system id per row + :param n_systems: number of systems being augmented + :return: list of length ``n_systems``; entry ``i`` selects the rows of system ``i`` + :raises ValueError: if an id is negative or there are more distinct ids than systems + """ + device = system_ids.device + if len(system_ids) == 0: + return [ + torch.zeros(0, dtype=torch.long, device=device) for _ in range(n_systems) + ] + + min_id = int(system_ids.min().item()) + max_id = int(system_ids.max().item()) + if min_id < 0: + raise ValueError("Encountered output samples with negative system indices.") + + if max_id < n_systems: + # ids already index into [0, n_systems): group directly (sparse rows are fine). + return [ + torch.nonzero(system_ids == i, as_tuple=False).reshape(-1) + for i in range(n_systems) + ] + + # ids are arbitrary labels: map the sorted distinct ids onto systems 0..n_systems-1. + unique_ids: list[int] = sorted(set(system_ids.tolist())) + if len(unique_ids) > n_systems: + raise ValueError( + f"TensorMap block has {len(unique_ids)} distinct system indices " + f"but only {n_systems} systems were provided." + ) + id_to_pos = {uid: pos for pos, uid in enumerate(unique_ids)} + remapped = torch.tensor( + [id_to_pos[int(sid)] for sid in system_ids.tolist()], + dtype=torch.long, + device=device, + ) + return [ + torch.nonzero(remapped == i, as_tuple=False).reshape(-1) + for i in range(n_systems) + ] + + +def _block_row_indices_by_system( + block: TensorBlock, + n_systems: int, +) -> list[torch.Tensor]: + """Return row-index tensors into ``block.values``, one per system. + + With a single system every row belongs to it (any ``"system"`` label is ignored). + With several systems the ``"system"`` column is required and is interpreted by + :func:`_row_indices_from_system_ids`. + + :param block: block whose ``samples`` may contain a ``"system"`` column + :param n_systems: number of systems being augmented + :return: list of length ``n_systems``; entry ``i`` selects all rows of system ``i`` + """ + if "system" not in block.samples.names: + if n_systems == 1: + return [torch.arange(block.values.shape[0], device=block.values.device)] + raise ValueError( + "Rotational augmentation expects output samples to include a 'system' " + "dimension when transforming multiple systems." + ) + system_ids = block.samples.column("system").to(dtype=torch.long) + return _row_indices_from_system_ids(system_ids, n_systems) + + +def _gradient_row_indices_by_system( + grad_block: TensorBlock, + parent_block: TensorBlock, + n_systems: int, +) -> list[torch.Tensor]: + """Return row-index tensors into a gradient block, one per system. + + Gradient samples carry a ``"sample"`` column indexing into the parent block (the + metatensor convention) rather than their own ``"system"`` column, so the system of + each gradient row is read from the parent block's ``"system"`` column. + + :param grad_block: gradient block to route + :param parent_block: the value block this gradient is attached to + :param n_systems: number of systems being augmented + :return: list of length ``n_systems``; entry ``i`` selects the gradient rows of + system ``i`` + """ + if n_systems == 1: + return [ + torch.arange(grad_block.values.shape[0], device=grad_block.values.device) + ] + if "sample" not in grad_block.samples.names: + raise ValueError( + "Gradient samples are expected to include a 'sample' dimension indexing " + "into the parent block." + ) + if "system" not in parent_block.samples.names: + raise ValueError( + "Rotational augmentation expects parent samples to include a 'system' " + "dimension when transforming gradients of multiple systems." + ) + parent_system = parent_block.samples.column("system").to(dtype=torch.long) + sample_index = grad_block.samples.column("sample").to(dtype=torch.long) + return _row_indices_from_system_ids(parent_system[sample_index], n_systems) + + +def _has_spherical_axis(tmap: TensorMap) -> bool: + """Whether any block of ``tmap`` carries a spherical component axis. + + :param tmap: TensorMap to inspect + :return: ``True`` if any block has an ``o3_mu``/``o3_mu_1``/``o3_mu_2`` component + """ + for block in tmap.blocks(): + for component in block.components: + if component.names[0] in _SPHERICAL_AXIS_TO_LAMBDA: + return True + return False + + +def _transform_single_system( + system: System, + transformation: torch.Tensor, +) -> System: + """Apply an O(3) transformation to a single System. + + Rotates positions, cell vectors, any registered per-atom data, and all neighbor-list + displacement vectors. Types and pbc flags are unchanged. + + Registered data is rotated by the same machinery as targets (see + :func:`_transform_tmap`), as a single-system batch: scalar blocks pass through and + Cartesian (``xyz``/``xyz_1``/``xyz_2``) blocks are rotated by ``transformation``. + No Wigner-D matrices are computed for System data, so data carrying a spherical + (``o3_mu``) axis cannot be rotated and is passed through unchanged. + + :param system: input system + :param transformation: (3, 3) rotation or improper-rotation matrix + :return: new System with transformed geometry + """ + new_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for data_name in system.known_data(): + data = system.get_data(data_name) + if _has_spherical_axis(data): + # Rotating spherical data needs Wigner-D matrices, which are not computed + # for System data; rather than rotate it incorrectly we pass it through + # unchanged (attaching spherical data to a System is allowed but exotic). + new_system.add_data(data_name, data) + else: + new_system.add_data( + data_name, + _transform_tmap(data_name, data, [system], [transformation], {}), + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + # neighbor vectors are stored as (N, 3, 1); squeeze/unsqueeze around the matmul + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + register_autograd_neighbors(new_system, neighbors) + new_system.add_neighbor_list(options, neighbors) + return new_system + + +def _contract_component_axes( + values: torch.Tensor, + matrices: list[torch.Tensor], +) -> torch.Tensor: + """Rotate each component axis of ``values`` by its matrix. + + ``values`` has shape ``(n_rows, d_1, ..., d_k, n_properties)`` and ``matrices[j]`` + (shape ``(d_j, d_j)``) is contracted with component axis ``j`` as + ``out[..., A, ...] = sum_a matrices[j][A, a] * values[..., a, ...]``. + + :param values: values tensor of a value or gradient block + :param matrices: one rotation matrix per component axis (empty for scalars) + :return: rotated values, same shape as the input + """ + if len(matrices) == 0: + return values + n_axes = len(matrices) + in_subscript = "i" + _EINSUM_IN[:n_axes] + "p" + out_subscript = "i" + _EINSUM_OUT[:n_axes] + "p" + matrix_subscripts = [_EINSUM_OUT[j] + _EINSUM_IN[j] for j in range(n_axes)] + equation = ",".join(matrix_subscripts + [in_subscript]) + "->" + out_subscript + return torch.einsum(equation, *matrices, values) + + +def _axis_matrices_and_parity( + name: str, + components: list, + key, + R: torch.Tensor, + wigner_D_matrices: dict[int, list[torch.Tensor]], + system_index: int, + is_inverted: bool, +) -> tuple[list[torch.Tensor], int]: + """Pick the rotation matrix for each component axis and the O(3) inversion parity. + + Cartesian axes (``xyz``/``xyz_1``/``xyz_2``) use ``R`` directly, so improper + rotations flip vectors automatically. Spherical axes (``o3_mu``/``o3_mu_1``/ + ``o3_mu_2``) use the proper-rotation Wigner-D matrix of the matching ``o3_lambda`` + plus a ``(-1)^ell * sigma`` parity factor accumulated whenever ``R`` is improper. + + :param name: TensorMap name, used only in error messages + :param components: component :class:`Labels` of the block (or gradient block) + :param key: the parent block's key, supplying ``o3_lambda``/``o3_sigma`` values + :param R: this system's (3, 3) transformation matrix + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :param system_index: index selecting this system's Wigner-D matrix + :param is_inverted: whether ``R`` is an improper rotation (``det(R) < 0``) + :return: ``(matrices, parity)`` with one matrix per component axis + :raises ValueError: if a component axis is neither Cartesian nor spherical + """ + matrices: list[torch.Tensor] = [] + parity = 1 + for component in components: + axis_name = component.names[0] + if axis_name in _CARTESIAN_AXES: + matrices.append(R) + elif axis_name in _SPHERICAL_AXIS_TO_LAMBDA: + ell = int(key[_SPHERICAL_AXIS_TO_LAMBDA[axis_name]]) + matrices.append(wigner_D_matrices[ell][system_index]) + if is_inverted: + sigma = int(key[_SPHERICAL_AXIS_TO_SIGMA[axis_name]]) + parity *= ((-1) ** ell) * sigma + else: + raise ValueError( + f"TensorMap '{name}' has component axis '{axis_name}', which is " + "neither a Cartesian ('xyz'/'xyz_1'/'xyz_2') nor spherical " + "('o3_mu'/'o3_mu_1'/'o3_mu_2') axis; rotational augmentation cannot " + "transform it." + ) + return matrices, parity + + +def _transform_component_values( + name: str, + values: torch.Tensor, + components: list, + key, + row_indices: list[torch.Tensor], + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], +) -> torch.Tensor: + """Rotate the values of a single value or gradient block, per system. + + :param name: TensorMap name, used only in error messages + :param values: the block's values tensor + :param components: the block's component :class:`Labels` + :param key: the parent block's key (for spherical ``o3_lambda``/``o3_sigma``) + :param row_indices: per-system row indices into ``values`` + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :return: new values tensor with each system's rows rotated + """ + new_values = values.clone() + for system_index, rows in enumerate(row_indices): + if len(rows) == 0: + continue + R = transformations[system_index] + is_inverted = bool(torch.det(R) < 0) + matrices, parity = _axis_matrices_and_parity( + name, components, key, R, wigner_D_matrices, system_index, is_inverted + ) + rotated = _contract_component_axes(values[rows], matrices) + if parity != 1: + rotated = rotated * parity + new_values[rows] = rotated + return new_values + + +def _transform_block( + name: str, + key, + block: TensorBlock, + n_systems: int, + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], +) -> TensorBlock: + """Rotate one block and all of its gradients. + + Value rows are routed to systems by their ``"system"`` column; every gradient is + routed by the parent block's ``"system"`` label via + :func:`_gradient_row_indices_by_system` (the gradient's ``"sample"`` column indexes + the parent). Each gradient reuses the parent ``key``, so a spherical component axis + inherited from the value keeps the value's ``o3_lambda``/``o3_sigma``, while the + extra Cartesian gradient-direction axis is rotated by ``R``. + + :param name: TensorMap name, used only in error messages + :param key: the block's key + :param block: the value block to rotate + :param n_systems: number of systems being augmented + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :return: new block with rotated values and gradients + """ + row_indices = _block_row_indices_by_system(block, n_systems) + new_block = TensorBlock( + values=_transform_component_values( + name, + block.values, + block.components, + key, + row_indices, + transformations, + wigner_D_matrices, + ), + samples=block.samples, + components=block.components, + properties=block.properties, + ) + for gradient_name in block.gradients_list(): + grad_block = block.gradient(gradient_name) + grad_rows = _gradient_row_indices_by_system(grad_block, block, n_systems) + new_block.add_gradient( + gradient_name, + TensorBlock( + values=_transform_component_values( + name, + grad_block.values, + grad_block.components, + key, + grad_rows, + transformations, + wigner_D_matrices, + ), + samples=grad_block.samples, + components=grad_block.components, + properties=grad_block.properties, + ), + ) + return new_block + + +def _transform_tmap( + name: str, + tmap: TensorMap, + systems: list[System], + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], +) -> TensorMap: + """Rotate every block (and its gradients) of a TensorMap. + + The tensor character of each component axis is inferred from its name, so a single + TensorMap may freely mix scalar, Cartesian and spherical blocks, and blocks may + carry gradients. + + :param name: used only in error messages + :param tmap: TensorMap to rotate + :param systems: input systems; length gives the batch size + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :return: new TensorMap with rotated values and gradients + :raises ValueError: if a component axis name is not recognised + """ + n_systems = len(systems) + new_blocks = [ + _transform_block( + name, key, block, n_systems, transformations, wigner_D_matrices + ) + for key, block in tmap.items() + ] + return TensorMap(keys=tmap.keys, blocks=new_blocks) + + +def _apply_transformations( + systems: list[System], + targets: dict[str, TensorMap], + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], + extra_data: dict[str, TensorMap] | None = None, +) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]: + """Apply a batch of O(3) transformations to systems and TensorMaps simultaneously. + + Each element of ``transformations`` is a (3, 3) matrix (rotation or improper + rotation) applied to the corresponding system. TensorMaps in ``targets`` and + ``extra_data`` are transformed per system (Cartesian axes by ``R``, spherical axes + by the matching Wigner-D matrix), except that keys ending in ``"_mask"`` pass + through unchanged. + + :param systems: input systems, one per transformation + :param targets: TensorMaps to transform (e.g. model predictions to back-rotate) + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :param extra_data: additional TensorMaps to transform alongside targets + :return: ``(new_systems, new_targets, new_extra_data)`` + """ + new_systems = [ + _transform_single_system(system, R) + for system, R in zip(systems, transformations, strict=True) + ] + + new_targets: dict[str, TensorMap] = { + name: _transform_tmap(name, tmap, systems, transformations, wigner_D_matrices) + for name, tmap in targets.items() + } + + new_extra_data: dict[str, TensorMap] = {} + if extra_data is not None: + for key, value in extra_data.items(): + if key.endswith("_mask"): + new_extra_data[key] = value + else: + new_extra_data[key] = _transform_tmap( + key, value, systems, transformations, wigner_D_matrices + ) + + return new_systems, new_targets, new_extra_data + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +_SPHERICAL_KEY_NAMES = frozenset({"o3_lambda", "o3_lambda_1", "o3_lambda_2"}) + + +def _rotations_to_zyz( + rotations: list[torch.Tensor], +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Decompose a list of O(3) matrices into ZYZ Euler angles :math:`(\\alpha, \\beta, + \\gamma)`. + + For improper rotations (det < 0) the proper part ``-R`` is decomposed; the inversion + parity factor is handled separately when applying Wigner-D matrices. + + :param rotations: list of (3, 3) orthogonal tensors + :return: ``(alphas, betas, gammas)`` as 1-D numpy arrays of length + ``len(rotations)`` + """ + alphas = np.empty(len(rotations)) + betas = np.empty(len(rotations)) + gammas = np.empty(len(rotations)) + for i, R in enumerate(rotations): + proper_R = R if torch.det(R) > 0 else -R + # R = Rz(alpha) Ry(beta) Rz(gamma): element [2,2] = cos(beta) + cos_beta = float(proper_R[2, 2].clamp(-1.0, 1.0)) + beta = np.arccos(cos_beta) + sin_beta = np.sin(beta) + if abs(sin_beta) < 1e-10: + # Gimbal lock: only alpha +/- gamma is defined; fix gamma=0 + if cos_beta > 0: + alpha = float(torch.atan2(proper_R[1, 0], proper_R[0, 0])) + else: + alpha = float(torch.atan2(-proper_R[1, 0], -proper_R[0, 0])) + gamma = 0.0 + else: + # R[0,2]=cos(alpha)*sin(beta), R[1,2]=sin(alpha)*sin(beta): alpha via atan2 + # R[2,1]=sin(beta)*sin(gamma), R[2,0]=-sin(beta)*cos(gamma): gamma via atan2 + alpha = float(torch.atan2(proper_R[1, 2], proper_R[0, 2])) + gamma = float(torch.atan2(proper_R[2, 1], -proper_R[2, 0])) + alphas[i] = alpha + betas[i] = beta + gammas[i] = gamma + return alphas, betas, gammas + + +def random_rotations( + n: int, + *, + device: torch.device, + dtype: torch.dtype, + include_inversions: bool = False, + generator: torch.Generator | None = None, +) -> list[torch.Tensor]: + """Sample ``n`` uniformly distributed O(3) transformations. + + Rotations are sampled from the Haar measure on SO(3) via random unit quaternions. + When ``include_inversions`` is ``True``, each matrix is independently negated with + probability 0.5, giving a uniform distribution over the full O(3) group. + + :param n: number of transformations to generate + :param device: target device for the output tensors + :param dtype: target dtype for the output tensors + :param include_inversions: if ``True``, sample from O(3) instead of SO(3) + :param generator: optional :class:`torch.Generator` for reproducible sampling; when + ``None`` the global RNG is used + :return: list of ``n`` orthogonal (3, 3) tensors + """ + q = torch.randn(n, 4, device=device, dtype=dtype, generator=generator) + q = q / q.norm(dim=1, keepdim=True) + w, x, y, z = q.unbind(1) + # Quaternion to rotation matrix (standard formula) + R = torch.stack( + [ + 1 - 2 * (y * y + z * z), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x * x + z * z), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x * x + y * y), + ], + dim=1, + ).reshape(n, 3, 3) + if include_inversions: + signs = torch.randint(0, 2, (n,), device=device, generator=generator) * 2 - 1 + R = R * signs.to(dtype=dtype).reshape(n, 1, 1) + return list(R.unbind(0)) + + +def apply_transformations( + systems: list[System], + targets: dict[str, TensorMap], + transformations: list[torch.Tensor], + extra_data: dict[str, TensorMap] | None = None, +) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]: + """Apply a batch of O(3) transformations to systems and TensorMaps simultaneously. + + Wigner-D matrices are derived automatically from ``transformations``; the tensor + type (scalar, Cartesian, spherical) is inferred from each TensorMap's component axis + names. Keys in ``extra_data`` that end in ``"_mask"`` pass through unchanged. + + :param systems: input systems, one per transformation + :param targets: model output TensorMaps to back-rotate (e.g. predicted energies, + forces, or spherical features) + :param transformations: per-system (3, 3) orthogonal matrices; use + :func:`random_rotations` to generate these + :param extra_data: additional TensorMaps to transform alongside ``targets`` + :return: ``(new_systems, new_targets, new_extra_data)`` + :raises ValueError: if ``len(systems) != len(transformations)``, any matrix is not a + (3, 3) orthogonal matrix, or the transformations, systems and TensorMaps do not + share a common dtype and device + """ + if len(systems) != len(transformations): + raise ValueError( + f"Expected one transformation per system, got {len(transformations)} " + f"transformations for {len(systems)} systems." + ) + for i, R in enumerate(transformations): + if R.shape != (3, 3): + raise ValueError( + f"Transformation {i} has shape {tuple(R.shape)}; expected (3, 3)." + ) + identity = torch.eye(3, device=R.device, dtype=R.dtype) + if not torch.allclose(R @ R.T, identity, atol=1e-5): + raise ValueError( + f"Transformation {i} is not orthogonal (R @ R.T deviates from I)." + ) + + if len(transformations) > 0: + # Everything is contracted with the transformations (or the Wigner-D matrices + # derived from them), so dtype and device must match throughout, otherwise the + # matmuls below fail with a much less helpful message. + reference = transformations[0] + for i, R in enumerate(transformations): + if R.dtype != reference.dtype or R.device != reference.device: + raise ValueError( + f"Transformation {i} has dtype/device ({R.dtype}, {R.device}) " + f"differing from transformation 0 ({reference.dtype}, " + f"{reference.device}); all transformations must agree." + ) + for i, system in enumerate(systems): + if ( + system.positions.dtype != reference.dtype + or system.positions.device != reference.device + ): + raise ValueError( + f"System {i} has positions with dtype/device " + f"({system.positions.dtype}, {system.positions.device}) differing " + f"from the transformations ({reference.dtype}, " + f"{reference.device})." + ) + for label, tmap in list(targets.items()) + ( + [(k, v) for k, v in extra_data.items() if not k.endswith("_mask")] + if extra_data is not None + else [] + ): + for block in tmap.blocks(): + if ( + block.values.dtype != reference.dtype + or block.values.device != reference.device + ): + raise ValueError( + f"TensorMap '{label}' has values with dtype/device " + f"({block.values.dtype}, {block.values.device}) differing " + f"from the transformations ({reference.dtype}, " + f"{reference.device})." + ) + + # Determine the highest angular momentum present across all TensorMaps + ell_max = 0 + all_tmaps = list(targets.values()) + if extra_data is not None: + all_tmaps += [v for k, v in extra_data.items() if not k.endswith("_mask")] + for tmap in all_tmaps: + for name in tmap.keys.names: + if name in _SPHERICAL_KEY_NAMES: + col = tmap.keys.column(name) + if len(col) > 0: + ell_max = max(ell_max, int(col.max())) + + if len(transformations) > 0: + device = transformations[0].device + dtype = transformations[0].dtype + angles = _rotations_to_zyz(transformations) + wigner_batch = compute_wigner_batch(ell_max, angles, device=device, dtype=dtype) + # Unbind the batch dim: {ell: (N, 2l+1, 2l+1)} to {ell: [D_0, ..., D_{N-1}]} + wigner_D_matrices: dict[int, list[torch.Tensor]] = { + ell: list(D.unbind(0)) for ell, D in wigner_batch.items() + } + else: + wigner_D_matrices = {} + + return _apply_transformations( + systems, targets, transformations, wigner_D_matrices, extra_data + ) From 01f7b47a5e89c292a2b06e946a3bf2a40a24713a Mon Sep 17 00:00:00 2001 From: ppegolo Date: Fri, 26 Jun 2026 21:49:09 +0200 Subject: [PATCH 3/3] Add tests for O(3) augmentation Cross-checks the spherical (Wigner-D) path against the trivially-correct Cartesian path for general rotations, including both gimbal-lock branches of the ZYZ decomposition (beta = 0 and beta = pi) and improper rotations. Covers gradient rotation and per-system routing (including unsorted parent "system" labels and the arbitrary-dataset-index remap), System geometry, registered scalar/Cartesian/spherical data, neighbor-list vectors, `random_rotations` (orthogonality, inversions, reproducible generator) and the public-API validation errors. --- python/metatomic_torch/tests/augmentation.py | 739 +++++++++++++++++++ 1 file changed, 739 insertions(+) create mode 100644 python/metatomic_torch/tests/augmentation.py diff --git a/python/metatomic_torch/tests/augmentation.py b/python/metatomic_torch/tests/augmentation.py new file mode 100644 index 000000000..9bcaba8e9 --- /dev/null +++ b/python/metatomic_torch/tests/augmentation.py @@ -0,0 +1,739 @@ +import numpy as np +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import ( + NeighborListOptions, + System, + apply_transformations, + random_rotations, + register_autograd_neighbors, +) +from metatomic.torch._augmentation import _apply_transformations, _rotations_to_zyz +from metatomic.torch._augmentation._wigner import compute_wigner_batch + + +def _axis_angle(axis, theta): + """A general (non-degenerate, beta != 0) rotation matrix from axis and angle.""" + axis = np.asarray(axis, dtype=float) + axis = axis / np.linalg.norm(axis) + x, y, z = axis + c, s = np.cos(theta), np.sin(theta) + one_minus_c = 1.0 - c + return np.array( + [ + [ + c + x * x * one_minus_c, + x * y * one_minus_c - z * s, + x * z * one_minus_c + y * s, + ], + [ + y * x * one_minus_c + z * s, + c + y * y * one_minus_c, + y * z * one_minus_c - x * s, + ], + [ + z * x * one_minus_c - y * s, + z * y * one_minus_c + x * s, + c + z * z * one_minus_c, + ], + ] + ) + + +# Change of basis from Cartesian (x, y, z) to real ell=1 spherical harmonics, ordered +# (m=-1, 0, +1) = (y, z, x). The real ell=1 Wigner-D matrix satisfies D1 = C @ R @ C.T, +# which lets us cross-check the spherical path (Wigner-D) against the trivially-correct +# Cartesian path under arbitrary rotations. +_CART_TO_SPHERICAL_L1 = torch.tensor( + [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=torch.float64 +) + + +def _make_system(types, positions=None, cell=None, pbc=None): + n_atoms = len(types) + if positions is None: + positions = torch.zeros((n_atoms, 3), dtype=torch.float64) + if cell is None: + cell = torch.zeros((3, 3), dtype=torch.float64) + if pbc is None: + pbc = torch.tensor([False, False, False]) + return System( + types=torch.tensor(types, dtype=torch.int32), + positions=positions, + cell=cell, + pbc=pbc, + ) + + +def _rotation_batch(alphas): + transformations = [] + for alpha in alphas: + transformations.append( + torch.tensor( + [ + [np.cos(alpha), -np.sin(alpha), 0.0], + [np.sin(alpha), np.cos(alpha), 0.0], + [0.0, 0.0, 1.0], + ], + dtype=torch.float64, + ) + ) + + zeros = np.zeros(len(alphas)) + wigner_D_matrices = { + ell: list(matrix.unbind(0)) + for ell, matrix in compute_wigner_batch( + 1, + (np.asarray(alphas), zeros, zeros), + device=torch.device("cpu"), + dtype=torch.float64, + ).items() + } + return transformations, wigner_D_matrices + + +def _row_indices(samples, n_systems): + system_ids = samples.column("system").to(dtype=torch.long) + return [ + torch.nonzero(system_ids == system_index, as_tuple=False).reshape(-1) + for system_index in range(n_systems) + ] + + +def test_sparse_atomic_basis_rank2_augmentation_with_missing_system_rows(): + """Rank-2 spherical features rotate on each mu axis, and empty system row-groups + are a no-op. + + The single block carries two component axes (``o3_mu_1``, ``o3_mu_2``) and all of + its rows belong to system 0; system 1 contributes no rows (the "missing system + rows"). The test confirms only system 0's rows are transformed -- by + ``D1 @ v @ D1.T`` on the two mu axes -- that the empty row-group for system 1 leaves + the block untouched, and that the values actually change (guarding against an + accidental identity transform). + """ + systems = [_make_system([1, 1]), _make_system([8, 8])] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 2, np.pi]) + + components = [ + Labels( + ["o3_mu_1"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + ["o3_mu_2"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ), + ] + property_labels = Labels( + ["n_1", "n_2"], + torch.tensor([[0, 0]], dtype=torch.int32), + ) + values = torch.arange(18, dtype=torch.float64).reshape(2, 3, 3, 1) + tensor = TensorMap( + Labels( + ["o3_lambda_1", "o3_lambda_2", "o3_sigma_1", "o3_sigma_2", "atom_type"], + torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32), + ), + [ + TensorBlock( + values=values, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=components, + properties=property_labels, + ) + ], + ) + + _, augmented_targets, _ = _apply_transformations( + systems, + {"target": tensor}, + transformations, + wigner_D_matrices, + ) + augmented = augmented_targets["target"] + + expected_values = values.clone() + rows = _row_indices(tensor.block().samples, len(systems))[0] + expected_values[rows] = torch.einsum( + "Aa,iabp,bB->iABp", + wigner_D_matrices[1][0], + values[rows], + wigner_D_matrices[1][0].T, + ) + expected = TensorMap( + tensor.keys, + [ + TensorBlock( + values=expected_values, + samples=tensor.block().samples, + components=tensor.block().components, + properties=tensor.block().properties, + ) + ], + ) + + assert augmented.block().samples == expected.block().samples + assert torch.allclose(augmented.block().values, expected.block().values, atol=1e-12) + assert not torch.allclose(augmented.block().values, values) + + +def test_system_positions_and_cell_are_rotated(): + # Non-trivial positions and cell so the rotation is observable; verifies that + # `_apply_transformations` does not silently leave the System unchanged. + positions_a = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype=torch.float64, + ) + positions_b = torch.tensor( + [[0.5, 0.5, 0.0], [1.0, -1.0, 0.5]], + dtype=torch.float64, + ) + cell_a = torch.eye(3, dtype=torch.float64) * 3.0 + cell_b = torch.eye(3, dtype=torch.float64) * 4.0 + pbc = torch.tensor([True, True, True]) + systems = [ + _make_system([1, 1, 1], positions=positions_a, cell=cell_a, pbc=pbc), + _make_system([8, 8], positions=positions_b, cell=cell_b, pbc=pbc), + ] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 3, np.pi / 4]) + + new_systems, _, _ = _apply_transformations( + systems, + {}, + transformations, + wigner_D_matrices, + ) + + assert len(new_systems) == 2 + for original, new, R in zip(systems, new_systems, transformations, strict=True): + assert torch.allclose(new.positions, original.positions @ R.T, atol=1e-12) + assert torch.allclose(new.cell, original.cell @ R.T, atol=1e-12) + # types and pbc must pass through unchanged + assert torch.equal(new.types, original.types) + assert torch.equal(new.pbc, original.pbc) + + +def test_random_rotations_are_orthogonal(): + rotations = random_rotations(20, device=torch.device("cpu"), dtype=torch.float64) + assert len(rotations) == 20 + identity = torch.eye(3, dtype=torch.float64) + for R in rotations: + assert R.shape == (3, 3) + assert torch.allclose(R @ R.T, identity, atol=1e-10) + assert abs(float(torch.det(R)) - 1.0) < 1e-10 + + +def test_random_rotations_include_inversions(): + # With n=100 the probability that all determinants have the same sign is 2^{-99}. + rotations = random_rotations( + 100, device=torch.device("cpu"), dtype=torch.float64, include_inversions=True + ) + dets = torch.tensor([float(torch.det(R)) for R in rotations], dtype=torch.float64) + assert torch.allclose(dets.abs(), torch.ones(100, dtype=torch.float64), atol=1e-10) + assert (dets > 0).any() and (dets < 0).any() + + +def test_apply_transformations_raises_on_length_mismatch(): + systems = [_make_system([1])] + two_rotations = [torch.eye(3, dtype=torch.float64)] * 2 + with pytest.raises(ValueError, match="one transformation per system"): + apply_transformations(systems, {}, two_rotations) + + +def test_apply_transformations_raises_on_non_orthogonal(): + systems = [_make_system([1])] + not_orthogonal = torch.tensor( + [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64 + ) + with pytest.raises(ValueError, match="not orthogonal"): + apply_transformations(systems, {}, [not_orthogonal]) + + +# Rotations exercising the full ZYZ decomposition + real Wigner-D path, including both +# gimbal-lock branches of `_rotations_to_zyz`: a generic beta != 0 rotation, a z-axis +# rotation (beta == 0), a 180-degree rotation about x (beta == pi), and an improper one. +_GENERAL_ROTATIONS = [ + torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64), + torch.tensor(_axis_angle([-2.0, 1.0, 0.5], 2.4), dtype=torch.float64), + torch.tensor(_axis_angle([0.0, 0.0, 1.0], 0.9), dtype=torch.float64), # beta=0 + torch.tensor(_axis_angle([1.0, 0.0, 0.0], np.pi), dtype=torch.float64), # beta=pi + # improper rotation (det = -1): a proper rotation composed with inversion + -torch.tensor(_axis_angle([0.3, -1.0, 2.0], 1.1), dtype=torch.float64), +] + + +@pytest.mark.parametrize("R", _GENERAL_ROTATIONS) +def test_general_rotation_ell1_wigner_matches_cartesian(R): + # For ell=1 the real Wigner-D matrix must equal C @ R_proper @ C.T, where R_proper + # is the proper part of R. This validates the Euler decomposition and complex->real + # conversion for general (non-degenerate) rotations, independently of the rest of + # the augmentation machinery. + proper_R = R if torch.det(R) > 0 else -R + angles = _rotations_to_zyz([R]) + D = compute_wigner_batch(1, angles, device=torch.device("cpu"), dtype=torch.float64) + expected = _CART_TO_SPHERICAL_L1 @ proper_R @ _CART_TO_SPHERICAL_L1.T + assert torch.allclose(D[1][0], expected, atol=1e-12) + + +@pytest.mark.parametrize("R", _GENERAL_ROTATIONS) +def test_general_rotation_spherical_matches_cartesian_vector(R): + # End-to-end through the public API: a Cartesian vector target and a spherical ell=1 + # (sigma=1, i.e. a true vector) target encoding the same vectors via C must stay + # related by C after augmentation. Covers general rotations *and* the inversion + # parity factor for the improper case. + systems = [_make_system([1, 8])] + cartesian_vectors = torch.randn(2, 3, 1, dtype=torch.float64) + + cartesian = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=cartesian_vectors, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + # spherical encoding: w = C @ v along the component axis + spherical_values = torch.einsum( + "Aa,iap->iAp", _CART_TO_SPHERICAL_L1, cartesian_vectors + ) + spherical = TensorMap( + Labels(["o3_lambda", "o3_sigma"], torch.tensor([[1, 1]], dtype=torch.int32)), + [ + TensorBlock( + values=spherical_values, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[ + Labels( + ["o3_mu"], torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1) + ) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + + _, out, _ = apply_transformations( + systems, {"cart": cartesian, "spher": spherical}, [R] + ) + + rotated_cart = out["cart"].block().values + rotated_spher = out["spher"].block().values + expected_spher = torch.einsum("Aa,iap->iAp", _CART_TO_SPHERICAL_L1, rotated_cart) + assert torch.allclose(rotated_spher, expected_spher, atol=1e-12) + + +def test_scalar_energy_gradients_are_rotated(): + # A scalar (energy-like) block carrying positions and strain gradients across two + # systems. Positions gradients transform as vectors (R @ g); strain gradients as + # rank-2 Cartesian tensors (R @ S @ R.T). Verifies gradient support and per-system + # routing through the parent block's "system" column. + systems = [ + _make_system([1, 1]), + _make_system([8, 8, 8]), + ] + R0 = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + R1 = torch.tensor(_axis_angle([0.0, 1.0, 1.0], 1.9), dtype=torch.float64) + + values = torch.tensor([[1.0], [2.0]], dtype=torch.float64) + pos_grad = torch.randn(5, 3, 1, dtype=torch.float64) # 2 + 3 atoms + strain_grad = torch.randn(2, 3, 3, 1, dtype=torch.float64) + + block = TensorBlock( + values=values, + samples=Labels(["system"], torch.tensor([[0], [1]], dtype=torch.int32)), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ) + block.add_gradient( + "positions", + TensorBlock( + values=pos_grad, + samples=Labels( + ["sample", "atom"], + torch.tensor( + [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], dtype=torch.int32 + ), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ), + ) + block.add_gradient( + "strain", + TensorBlock( + values=strain_grad, + samples=Labels(["sample"], torch.tensor([[0], [1]], dtype=torch.int32)), + components=[ + Labels(["xyz_1"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)), + Labels(["xyz_2"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)), + ], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ), + ) + tensor = TensorMap(Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), [block]) + + _, out, _ = apply_transformations(systems, {"energy": tensor}, [R0, R1]) + out_block = out["energy"].block() + + # scalar values unchanged + assert torch.allclose(out_block.values, values) + + # positions gradients: rows 0,1 -> R0, rows 2,3,4 -> R1 + expected_pos = pos_grad.clone() + expected_pos[:2] = torch.einsum("Aa,iap->iAp", R0, pos_grad[:2]) + expected_pos[2:] = torch.einsum("Aa,iap->iAp", R1, pos_grad[2:]) + assert torch.allclose(out_block.gradient("positions").values, expected_pos) + + # strain gradients: row 0 -> R0 S R0.T, row 1 -> R1 S R1.T + expected_strain = strain_grad.clone() + expected_strain[0] = torch.einsum("Aa,abp,Bb->ABp", R0, strain_grad[0], R0) + expected_strain[1] = torch.einsum("Aa,abp,Bb->ABp", R1, strain_grad[1], R1) + assert torch.allclose(out_block.gradient("strain").values, expected_strain) + + +def test_positions_gradient_routes_by_parent_system_label(): + """Positions gradients follow the parent block's "system" label, not row position. + + The parent energy rows are given in non-sorted label order ([1, 0]); each gradient + row must be rotated by the transformation of the system named in *its parent row's* + "system" label, exactly as the values are routed. A positional routing (pairing + gradient ``sample == i`` with ``transformations[i]``) would mis-pair the gradient + rows with the systems whenever the parent rows are not sorted by label. + """ + # systems are passed sorted by label: systems[0] <-> label 0, systems[1] <-> label 1 + systems = [_make_system([1, 1]), _make_system([8, 8, 8])] + R0 = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + R1 = torch.tensor(_axis_angle([0.0, 1.0, 1.0], 1.9), dtype=torch.float64) + + # parent rows in non-sorted order: row 0 -> system label 1, row 1 -> system label 0 + values = torch.tensor([[1.0], [2.0]], dtype=torch.float64) + block = TensorBlock( + values=values, + samples=Labels(["system"], torch.tensor([[1], [0]], dtype=torch.int32)), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ) + # sample 0 -> parent row 0 (label 1, 3 atoms); sample 1 -> parent row 1 (label 0, 2) + pos_grad = torch.randn(5, 3, 1, dtype=torch.float64) + block.add_gradient( + "positions", + TensorBlock( + values=pos_grad, + samples=Labels( + ["sample", "atom"], + torch.tensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], dtype=torch.int32 + ), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ), + ) + tensor = TensorMap(Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), [block]) + + _, out, _ = apply_transformations(systems, {"energy": tensor}, [R0, R1]) + grad = out["energy"].block().gradient("positions").values + + expected = pos_grad.clone() + # parent row 0 has label 1 -> R1 (gradient rows with sample == 0: indices 0,1,2) + expected[:3] = torch.einsum("Aa,iap->iAp", R1, pos_grad[:3]) + # parent row 1 has label 0 -> R0 (gradient rows with sample == 1: indices 3,4) + expected[3:] = torch.einsum("Aa,iap->iAp", R0, pos_grad[3:]) + assert torch.allclose(grad, expected, atol=1e-12) + + +def test_unsupported_component_axis_raises(): + # A component axis that is neither Cartesian nor spherical must raise rather than be + # silently passed through unchanged. + systems = [_make_system([1])] + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=torch.zeros(1, 3, 1, dtype=torch.float64), + samples=Labels(["system"], torch.tensor([[0]], dtype=torch.int32)), + components=[ + Labels( + ["direction"], torch.arange(3, dtype=torch.int32).reshape(-1, 1) + ) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + with pytest.raises(ValueError, match="neither a Cartesian"): + apply_transformations( + systems, {"bad": tensor}, [torch.eye(3, dtype=torch.float64)] + ) + + +def test_neighbor_list_vectors_are_rotated(): + R = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + system = _make_system( + [1, 1], + positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float64), + ) + options = NeighborListOptions(cutoff=2.0, full_list=True, strict=False) + vectors = torch.tensor( + [[[1.0], [0.0], [0.0]]], dtype=torch.float64 + ) # (1 pair, 3, 1) + neighbors = TensorBlock( + values=vectors.clone(), + samples=Labels( + [ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + torch.tensor([[0, 1, 0, 0, 0]], dtype=torch.int32), + ), + components=[Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1))], + properties=Labels(["distance"], torch.tensor([[0]], dtype=torch.int32)), + ) + register_autograd_neighbors(system, neighbors) + system.add_neighbor_list(options, neighbors) + + new_systems, _, _ = apply_transformations([system], {}, [R]) + + new_vectors = new_systems[0].get_neighbor_list(options).values + expected = (vectors.squeeze(-1) @ R.T).unsqueeze(-1) + assert torch.allclose(new_vectors, expected, atol=1e-12) + + +def test_system_custom_data_is_rotated_by_tensor_type(): + """Registered System data is rotated by the same machinery as targets. + + Scalar blocks must pass through unchanged, ``xyz`` vector blocks must rotate by + ``R``, and a multi-block TensorMap must be handled block-by-block (the previous + implementation rejected anything other than a single block). This is the path + exercised when a model carries per-atom geometric quantities (e.g. local frames) as + System data that has to follow the rotation of the structure. + """ + R = torch.tensor(_axis_angle([0.3, -0.7, 0.5], 0.9), dtype=torch.float64) + system = _make_system( + [1, 8], + positions=torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64), + ) + + # scalar per-atom data spread over two blocks: both must pass through untouched + scalar = TensorMap( + keys=Labels(["block"], torch.tensor([[0], [1]], dtype=torch.int32)), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0]], dtype=torch.float64), + samples=Labels.range("atom", 2), + components=[], + properties=Labels.range("p", 1), + ), + TensorBlock( + values=torch.tensor([[3.0], [4.0]], dtype=torch.float64), + samples=Labels.range("atom", 2), + components=[], + properties=Labels.range("p", 1), + ), + ], + ) + # xyz vector per-atom data: must rotate by R on the component axis + vector_values = torch.tensor( + [[[1.0], [0.0], [0.0]], [[0.0], [2.0], [0.0]]], dtype=torch.float64 + ) # (2 atoms, 3, 1) + vector = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=vector_values.clone(), + samples=Labels.range("atom", 2), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels.range("p", 1), + ) + ], + ) + system.add_data("custom::scalar", scalar) + system.add_data("custom::vector", vector) + + new_systems, _, _ = apply_transformations([system], {}, [R]) + new_system = new_systems[0] + + new_scalar = new_system.get_data("custom::scalar") + for block_id in range(len(scalar.keys)): + assert torch.allclose( + new_scalar.block_by_id(block_id).values, + scalar.block_by_id(block_id).values, + ) + + new_vector = new_system.get_data("custom::vector").block().values + expected_vector = (vector_values.squeeze(-1) @ R.T).unsqueeze(-1) + assert torch.allclose(new_vector, expected_vector, atol=1e-12) + # sanity: the rotation actually changed the vector data + assert not torch.allclose(new_vector, vector_values) + + +def test_spherical_system_data_is_passed_through_unrotated(): + """Spherical System data is allowed but not rotated (no Wigner-D computed for it). + + Attaching an ``o3_mu`` block to a System via ``add_data`` is permitted, but the + augmentation does not build Wigner-D matrices for System data. Rather than crash or + rotate it incorrectly, such data must be passed through unchanged while the geometry + is still rotated -- so a model relying on it is responsible for its equivariance. + """ + R = torch.tensor(_axis_angle([0.2, 0.5, -0.8], 1.1), dtype=torch.float64) + system = _make_system( + [1, 8], + positions=torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64), + ) + spherical_values = torch.tensor( + [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float64 + ) # (2 atoms, 2*ell+1 = 3, 1) + mu = Labels(["o3_mu"], torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1)) + spherical = TensorMap( + keys=Labels( + ["o3_lambda", "o3_sigma"], torch.tensor([[1, 1]], dtype=torch.int32) + ), + blocks=[ + TensorBlock( + values=spherical_values.clone(), + samples=Labels.range("atom", 2), + components=[mu], + properties=Labels.range("p", 1), + ) + ], + ) + system.add_data("custom::spherical", spherical) + + # no targets -> no Wigner-D matrices are computed; this would KeyError if the + # spherical block were (incorrectly) sent through the rotation path + new_systems, _, _ = apply_transformations([system], {}, [R]) + + new_spherical = new_systems[0].get_data("custom::spherical").block().values + assert torch.allclose(new_spherical, spherical_values) + # the geometry itself is still rotated + assert torch.allclose(new_systems[0].positions, system.positions @ R.T, atol=1e-12) + + +def test_random_rotations_generator_is_reproducible(): + g1 = torch.Generator().manual_seed(12345) + g2 = torch.Generator().manual_seed(12345) + a = random_rotations( + 8, + device=torch.device("cpu"), + dtype=torch.float64, + include_inversions=True, + generator=g1, + ) + b = random_rotations( + 8, + device=torch.device("cpu"), + dtype=torch.float64, + include_inversions=True, + generator=g2, + ) + for Ra, Rb in zip(a, b, strict=True): + assert torch.equal(Ra, Rb) + + +def test_apply_transformations_raises_on_dtype_mismatch(): + systems = [_make_system([1, 8])] + R = torch.eye( + 3, dtype=torch.float32 + ) # mismatched vs float64 target/positions below + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=torch.zeros(2, 3, 1, dtype=torch.float64), + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + with pytest.raises(ValueError, match="dtype/device"): + apply_transformations(systems, {"t": tensor}, [R]) + + +def test_arbitrary_system_labels_are_remapped(): + # metatrain keeps each structure's original dataset index in the "system" column, + # so a 2-system batch can be labelled e.g. [92, 38] rather than [0, 1]. The i-th + # sorted label maps to system i: 38 -> system 0 (R0), 92 -> system 1 (R1). + systems = [_make_system([1, 8]), _make_system([1, 8])] + R0 = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + R1 = torch.tensor(_axis_angle([0.0, 1.0, 1.0], 1.9), dtype=torch.float64) + vectors = torch.randn(2, 3, 1, dtype=torch.float64) + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=vectors, + samples=Labels( + ["system", "atom"], + torch.tensor([[92, 0], [38, 0]], dtype=torch.int32), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + _, out, _ = apply_transformations(systems, {"t": tensor}, [R0, R1]) + result = out["t"].block().values + # row 0 has label 92 (sorted position 1) -> R1; row 1 has label 38 -> R0 + expected = vectors.clone() + expected[0] = torch.einsum("Aa,ap->Ap", R1, vectors[0]) + expected[1] = torch.einsum("Aa,ap->Ap", R0, vectors[1]) + assert torch.allclose(result, expected, atol=1e-12) + + +def test_too_many_distinct_system_ids_raises(): + # More distinct "system" labels than systems is genuinely ambiguous and must raise. + systems = [_make_system([1]), _make_system([8])] + R = torch.eye(3, dtype=torch.float64) + tensor = TensorMap( + Labels(["o3_lambda", "o3_sigma"], torch.tensor([[1, 1]], dtype=torch.int32)), + [ + TensorBlock( + values=torch.zeros(3, 3, 1, dtype=torch.float64), + samples=Labels( + ["system", "atom"], + torch.tensor([[10, 0], [20, 0], [30, 0]], dtype=torch.int32), + ), + components=[ + Labels( + ["o3_mu"], torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1) + ) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + with pytest.raises(ValueError, match="distinct system indices"): + apply_transformations(systems, {"t": tensor}, [R, R])