diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index a8bf363a..dbbe3eb6 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -52,6 +52,7 @@ pick_output = torch.ops.metatomic.pick_output from . import ase_calculator # noqa: F401 +from ._augmentation import apply_transformations, random_rotations # noqa: F401 from .model import ( # noqa: F401 AtomisticModel, ModelInterface, diff --git a/python/metatomic_torch/metatomic/torch/_augmentation/__init__.py b/python/metatomic_torch/metatomic/torch/_augmentation/__init__.py new file mode 100644 index 00000000..0c347b95 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_augmentation/__init__.py @@ -0,0 +1,667 @@ +""" +O(3) augmentation: apply batched rotation/inversion transformations to Systems and +TensorMaps. + +metatomic stores positions and cell as row vectors (shape ``(N, 3)``), so a rotation +matrix ``R`` (3x3) acts as ``x @ R.T`` throughout this module. +""" + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import TensorBlock, TensorMap + +from .. import System, register_autograd_neighbors +from ._wigner import compute_wigner_batch + + +# Component-axis names recognised by the augmentation machinery. Cartesian axes are +# rotated by ``R`` directly (so improper rotations flip vectors); spherical axes are +# rotated by the Wigner-D matrix of the matching ``o3_lambda`` plus an explicit +# ``(-1)^ell * sigma`` inversion parity factor. +_CARTESIAN_AXES = frozenset({"xyz", "xyz_1", "xyz_2"}) +_SPHERICAL_AXIS_TO_LAMBDA = { + "o3_mu": "o3_lambda", + "o3_mu_1": "o3_lambda_1", + "o3_mu_2": "o3_lambda_2", +} +_SPHERICAL_AXIS_TO_SIGMA = { + "o3_mu": "o3_sigma", + "o3_mu_1": "o3_sigma_1", + "o3_mu_2": "o3_sigma_2", +} + +# einsum index letters for the per-axis component contraction (input lower, output +# upper). Six axes is far more than any realistic block (value + gradient) needs. +_EINSUM_IN = "abcdef" +_EINSUM_OUT = "ABCDEF" + + +def _row_indices_from_system_ids( + system_ids: torch.Tensor, + n_systems: int, +) -> list[torch.Tensor]: + """Group row indices by system, mapping each distinct ``system`` id to one system. + + The ``"system"`` column is not necessarily a 0-based batch index: metatrain keeps + each structure's original dataset index there, so the labels in a batch are an + arbitrary set (e.g. ``[67, 92, 38, ...]``). The i-th distinct id, **in sorted + order**, is taken to be system ``i``; when the ids already span ``[0, n_systems)`` + this reduces to the identity grouping. + + .. note:: + + This pairs the i-th sorted label with ``systems[i]``/``transformations[i]``, + which is only consistent with the (positional) system rotation when the caller + passes ``systems`` ordered by their label. metatrain satisfies this. + + :param system_ids: 1-D tensor with one system id per row + :param n_systems: number of systems being augmented + :return: list of length ``n_systems``; entry ``i`` selects the rows of system ``i`` + :raises ValueError: if an id is negative or there are more distinct ids than systems + """ + device = system_ids.device + if len(system_ids) == 0: + return [ + torch.zeros(0, dtype=torch.long, device=device) for _ in range(n_systems) + ] + + min_id = int(system_ids.min().item()) + max_id = int(system_ids.max().item()) + if min_id < 0: + raise ValueError("Encountered output samples with negative system indices.") + + if max_id < n_systems: + # ids already index into [0, n_systems): group directly (sparse rows are fine). + return [ + torch.nonzero(system_ids == i, as_tuple=False).reshape(-1) + for i in range(n_systems) + ] + + # ids are arbitrary labels: map the sorted distinct ids onto systems 0..n_systems-1. + unique_ids: list[int] = sorted(set(system_ids.tolist())) + if len(unique_ids) > n_systems: + raise ValueError( + f"TensorMap block has {len(unique_ids)} distinct system indices " + f"but only {n_systems} systems were provided." + ) + id_to_pos = {uid: pos for pos, uid in enumerate(unique_ids)} + remapped = torch.tensor( + [id_to_pos[int(sid)] for sid in system_ids.tolist()], + dtype=torch.long, + device=device, + ) + return [ + torch.nonzero(remapped == i, as_tuple=False).reshape(-1) + for i in range(n_systems) + ] + + +def _block_row_indices_by_system( + block: TensorBlock, + n_systems: int, +) -> list[torch.Tensor]: + """Return row-index tensors into ``block.values``, one per system. + + With a single system every row belongs to it (any ``"system"`` label is ignored). + With several systems the ``"system"`` column is required and is interpreted by + :func:`_row_indices_from_system_ids`. + + :param block: block whose ``samples`` may contain a ``"system"`` column + :param n_systems: number of systems being augmented + :return: list of length ``n_systems``; entry ``i`` selects all rows of system ``i`` + """ + if "system" not in block.samples.names: + if n_systems == 1: + return [torch.arange(block.values.shape[0], device=block.values.device)] + raise ValueError( + "Rotational augmentation expects output samples to include a 'system' " + "dimension when transforming multiple systems." + ) + system_ids = block.samples.column("system").to(dtype=torch.long) + return _row_indices_from_system_ids(system_ids, n_systems) + + +def _gradient_row_indices_by_system( + grad_block: TensorBlock, + parent_block: TensorBlock, + n_systems: int, +) -> list[torch.Tensor]: + """Return row-index tensors into a gradient block, one per system. + + Gradient samples carry a ``"sample"`` column indexing into the parent block (the + metatensor convention) rather than their own ``"system"`` column, so the system of + each gradient row is read from the parent block's ``"system"`` column. + + :param grad_block: gradient block to route + :param parent_block: the value block this gradient is attached to + :param n_systems: number of systems being augmented + :return: list of length ``n_systems``; entry ``i`` selects the gradient rows of + system ``i`` + """ + if n_systems == 1: + return [ + torch.arange(grad_block.values.shape[0], device=grad_block.values.device) + ] + if "sample" not in grad_block.samples.names: + raise ValueError( + "Gradient samples are expected to include a 'sample' dimension indexing " + "into the parent block." + ) + if "system" not in parent_block.samples.names: + raise ValueError( + "Rotational augmentation expects parent samples to include a 'system' " + "dimension when transforming gradients of multiple systems." + ) + parent_system = parent_block.samples.column("system").to(dtype=torch.long) + sample_index = grad_block.samples.column("sample").to(dtype=torch.long) + return _row_indices_from_system_ids(parent_system[sample_index], n_systems) + + +def _has_spherical_axis(tmap: TensorMap) -> bool: + """Whether any block of ``tmap`` carries a spherical component axis. + + :param tmap: TensorMap to inspect + :return: ``True`` if any block has an ``o3_mu``/``o3_mu_1``/``o3_mu_2`` component + """ + for block in tmap.blocks(): + for component in block.components: + if component.names[0] in _SPHERICAL_AXIS_TO_LAMBDA: + return True + return False + + +def _transform_single_system( + system: System, + transformation: torch.Tensor, +) -> System: + """Apply an O(3) transformation to a single System. + + Rotates positions, cell vectors, any registered per-atom data, and all neighbor-list + displacement vectors. Types and pbc flags are unchanged. + + Registered data is rotated by the same machinery as targets (see + :func:`_transform_tmap`), as a single-system batch: scalar blocks pass through and + Cartesian (``xyz``/``xyz_1``/``xyz_2``) blocks are rotated by ``transformation``. + No Wigner-D matrices are computed for System data, so data carrying a spherical + (``o3_mu``) axis cannot be rotated and is passed through unchanged. + + :param system: input system + :param transformation: (3, 3) rotation or improper-rotation matrix + :return: new System with transformed geometry + """ + new_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for data_name in system.known_data(): + data = system.get_data(data_name) + if _has_spherical_axis(data): + # Rotating spherical data needs Wigner-D matrices, which are not computed + # for System data; rather than rotate it incorrectly we pass it through + # unchanged (attaching spherical data to a System is allowed but exotic). + new_system.add_data(data_name, data) + else: + new_system.add_data( + data_name, + _transform_tmap(data_name, data, [system], [transformation], {}), + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + # neighbor vectors are stored as (N, 3, 1); squeeze/unsqueeze around the matmul + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + register_autograd_neighbors(new_system, neighbors) + new_system.add_neighbor_list(options, neighbors) + return new_system + + +def _contract_component_axes( + values: torch.Tensor, + matrices: list[torch.Tensor], +) -> torch.Tensor: + """Rotate each component axis of ``values`` by its matrix. + + ``values`` has shape ``(n_rows, d_1, ..., d_k, n_properties)`` and ``matrices[j]`` + (shape ``(d_j, d_j)``) is contracted with component axis ``j`` as + ``out[..., A, ...] = sum_a matrices[j][A, a] * values[..., a, ...]``. + + :param values: values tensor of a value or gradient block + :param matrices: one rotation matrix per component axis (empty for scalars) + :return: rotated values, same shape as the input + """ + if len(matrices) == 0: + return values + n_axes = len(matrices) + in_subscript = "i" + _EINSUM_IN[:n_axes] + "p" + out_subscript = "i" + _EINSUM_OUT[:n_axes] + "p" + matrix_subscripts = [_EINSUM_OUT[j] + _EINSUM_IN[j] for j in range(n_axes)] + equation = ",".join(matrix_subscripts + [in_subscript]) + "->" + out_subscript + return torch.einsum(equation, *matrices, values) + + +def _axis_matrices_and_parity( + name: str, + components: list, + key, + R: torch.Tensor, + wigner_D_matrices: dict[int, list[torch.Tensor]], + system_index: int, + is_inverted: bool, +) -> tuple[list[torch.Tensor], int]: + """Pick the rotation matrix for each component axis and the O(3) inversion parity. + + Cartesian axes (``xyz``/``xyz_1``/``xyz_2``) use ``R`` directly, so improper + rotations flip vectors automatically. Spherical axes (``o3_mu``/``o3_mu_1``/ + ``o3_mu_2``) use the proper-rotation Wigner-D matrix of the matching ``o3_lambda`` + plus a ``(-1)^ell * sigma`` parity factor accumulated whenever ``R`` is improper. + + :param name: TensorMap name, used only in error messages + :param components: component :class:`Labels` of the block (or gradient block) + :param key: the parent block's key, supplying ``o3_lambda``/``o3_sigma`` values + :param R: this system's (3, 3) transformation matrix + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :param system_index: index selecting this system's Wigner-D matrix + :param is_inverted: whether ``R`` is an improper rotation (``det(R) < 0``) + :return: ``(matrices, parity)`` with one matrix per component axis + :raises ValueError: if a component axis is neither Cartesian nor spherical + """ + matrices: list[torch.Tensor] = [] + parity = 1 + for component in components: + axis_name = component.names[0] + if axis_name in _CARTESIAN_AXES: + matrices.append(R) + elif axis_name in _SPHERICAL_AXIS_TO_LAMBDA: + ell = int(key[_SPHERICAL_AXIS_TO_LAMBDA[axis_name]]) + matrices.append(wigner_D_matrices[ell][system_index]) + if is_inverted: + sigma = int(key[_SPHERICAL_AXIS_TO_SIGMA[axis_name]]) + parity *= ((-1) ** ell) * sigma + else: + raise ValueError( + f"TensorMap '{name}' has component axis '{axis_name}', which is " + "neither a Cartesian ('xyz'/'xyz_1'/'xyz_2') nor spherical " + "('o3_mu'/'o3_mu_1'/'o3_mu_2') axis; rotational augmentation cannot " + "transform it." + ) + return matrices, parity + + +def _transform_component_values( + name: str, + values: torch.Tensor, + components: list, + key, + row_indices: list[torch.Tensor], + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], +) -> torch.Tensor: + """Rotate the values of a single value or gradient block, per system. + + :param name: TensorMap name, used only in error messages + :param values: the block's values tensor + :param components: the block's component :class:`Labels` + :param key: the parent block's key (for spherical ``o3_lambda``/``o3_sigma``) + :param row_indices: per-system row indices into ``values`` + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :return: new values tensor with each system's rows rotated + """ + new_values = values.clone() + for system_index, rows in enumerate(row_indices): + if len(rows) == 0: + continue + R = transformations[system_index] + is_inverted = bool(torch.det(R) < 0) + matrices, parity = _axis_matrices_and_parity( + name, components, key, R, wigner_D_matrices, system_index, is_inverted + ) + rotated = _contract_component_axes(values[rows], matrices) + if parity != 1: + rotated = rotated * parity + new_values[rows] = rotated + return new_values + + +def _transform_block( + name: str, + key, + block: TensorBlock, + n_systems: int, + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], +) -> TensorBlock: + """Rotate one block and all of its gradients. + + Value rows are routed to systems by their ``"system"`` column; every gradient is + routed by the parent block's ``"system"`` label via + :func:`_gradient_row_indices_by_system` (the gradient's ``"sample"`` column indexes + the parent). Each gradient reuses the parent ``key``, so a spherical component axis + inherited from the value keeps the value's ``o3_lambda``/``o3_sigma``, while the + extra Cartesian gradient-direction axis is rotated by ``R``. + + :param name: TensorMap name, used only in error messages + :param key: the block's key + :param block: the value block to rotate + :param n_systems: number of systems being augmented + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :return: new block with rotated values and gradients + """ + row_indices = _block_row_indices_by_system(block, n_systems) + new_block = TensorBlock( + values=_transform_component_values( + name, + block.values, + block.components, + key, + row_indices, + transformations, + wigner_D_matrices, + ), + samples=block.samples, + components=block.components, + properties=block.properties, + ) + for gradient_name in block.gradients_list(): + grad_block = block.gradient(gradient_name) + grad_rows = _gradient_row_indices_by_system(grad_block, block, n_systems) + new_block.add_gradient( + gradient_name, + TensorBlock( + values=_transform_component_values( + name, + grad_block.values, + grad_block.components, + key, + grad_rows, + transformations, + wigner_D_matrices, + ), + samples=grad_block.samples, + components=grad_block.components, + properties=grad_block.properties, + ), + ) + return new_block + + +def _transform_tmap( + name: str, + tmap: TensorMap, + systems: list[System], + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], +) -> TensorMap: + """Rotate every block (and its gradients) of a TensorMap. + + The tensor character of each component axis is inferred from its name, so a single + TensorMap may freely mix scalar, Cartesian and spherical blocks, and blocks may + carry gradients. + + :param name: used only in error messages + :param tmap: TensorMap to rotate + :param systems: input systems; length gives the batch size + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :return: new TensorMap with rotated values and gradients + :raises ValueError: if a component axis name is not recognised + """ + n_systems = len(systems) + new_blocks = [ + _transform_block( + name, key, block, n_systems, transformations, wigner_D_matrices + ) + for key, block in tmap.items() + ] + return TensorMap(keys=tmap.keys, blocks=new_blocks) + + +def _apply_transformations( + systems: list[System], + targets: dict[str, TensorMap], + transformations: list[torch.Tensor], + wigner_D_matrices: dict[int, list[torch.Tensor]], + extra_data: dict[str, TensorMap] | None = None, +) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]: + """Apply a batch of O(3) transformations to systems and TensorMaps simultaneously. + + Each element of ``transformations`` is a (3, 3) matrix (rotation or improper + rotation) applied to the corresponding system. TensorMaps in ``targets`` and + ``extra_data`` are transformed per system (Cartesian axes by ``R``, spherical axes + by the matching Wigner-D matrix), except that keys ending in ``"_mask"`` pass + through unchanged. + + :param systems: input systems, one per transformation + :param targets: TensorMaps to transform (e.g. model predictions to back-rotate) + :param transformations: per-system (3, 3) transformation matrices + :param wigner_D_matrices: ``{ell: [D_0, ..., D_{N-1}]}`` real Wigner-D matrices + :param extra_data: additional TensorMaps to transform alongside targets + :return: ``(new_systems, new_targets, new_extra_data)`` + """ + new_systems = [ + _transform_single_system(system, R) + for system, R in zip(systems, transformations, strict=True) + ] + + new_targets: dict[str, TensorMap] = { + name: _transform_tmap(name, tmap, systems, transformations, wigner_D_matrices) + for name, tmap in targets.items() + } + + new_extra_data: dict[str, TensorMap] = {} + if extra_data is not None: + for key, value in extra_data.items(): + if key.endswith("_mask"): + new_extra_data[key] = value + else: + new_extra_data[key] = _transform_tmap( + key, value, systems, transformations, wigner_D_matrices + ) + + return new_systems, new_targets, new_extra_data + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +_SPHERICAL_KEY_NAMES = frozenset({"o3_lambda", "o3_lambda_1", "o3_lambda_2"}) + + +def _rotations_to_zyz( + rotations: list[torch.Tensor], +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Decompose a list of O(3) matrices into ZYZ Euler angles :math:`(\\alpha, \\beta, + \\gamma)`. + + For improper rotations (det < 0) the proper part ``-R`` is decomposed; the inversion + parity factor is handled separately when applying Wigner-D matrices. + + :param rotations: list of (3, 3) orthogonal tensors + :return: ``(alphas, betas, gammas)`` as 1-D numpy arrays of length + ``len(rotations)`` + """ + alphas = np.empty(len(rotations)) + betas = np.empty(len(rotations)) + gammas = np.empty(len(rotations)) + for i, R in enumerate(rotations): + proper_R = R if torch.det(R) > 0 else -R + # R = Rz(alpha) Ry(beta) Rz(gamma): element [2,2] = cos(beta) + cos_beta = float(proper_R[2, 2].clamp(-1.0, 1.0)) + beta = np.arccos(cos_beta) + sin_beta = np.sin(beta) + if abs(sin_beta) < 1e-10: + # Gimbal lock: only alpha +/- gamma is defined; fix gamma=0 + if cos_beta > 0: + alpha = float(torch.atan2(proper_R[1, 0], proper_R[0, 0])) + else: + alpha = float(torch.atan2(-proper_R[1, 0], -proper_R[0, 0])) + gamma = 0.0 + else: + # R[0,2]=cos(alpha)*sin(beta), R[1,2]=sin(alpha)*sin(beta): alpha via atan2 + # R[2,1]=sin(beta)*sin(gamma), R[2,0]=-sin(beta)*cos(gamma): gamma via atan2 + alpha = float(torch.atan2(proper_R[1, 2], proper_R[0, 2])) + gamma = float(torch.atan2(proper_R[2, 1], -proper_R[2, 0])) + alphas[i] = alpha + betas[i] = beta + gammas[i] = gamma + return alphas, betas, gammas + + +def random_rotations( + n: int, + *, + device: torch.device, + dtype: torch.dtype, + include_inversions: bool = False, + generator: torch.Generator | None = None, +) -> list[torch.Tensor]: + """Sample ``n`` uniformly distributed O(3) transformations. + + Rotations are sampled from the Haar measure on SO(3) via random unit quaternions. + When ``include_inversions`` is ``True``, each matrix is independently negated with + probability 0.5, giving a uniform distribution over the full O(3) group. + + :param n: number of transformations to generate + :param device: target device for the output tensors + :param dtype: target dtype for the output tensors + :param include_inversions: if ``True``, sample from O(3) instead of SO(3) + :param generator: optional :class:`torch.Generator` for reproducible sampling; when + ``None`` the global RNG is used + :return: list of ``n`` orthogonal (3, 3) tensors + """ + q = torch.randn(n, 4, device=device, dtype=dtype, generator=generator) + q = q / q.norm(dim=1, keepdim=True) + w, x, y, z = q.unbind(1) + # Quaternion to rotation matrix (standard formula) + R = torch.stack( + [ + 1 - 2 * (y * y + z * z), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x * x + z * z), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x * x + y * y), + ], + dim=1, + ).reshape(n, 3, 3) + if include_inversions: + signs = torch.randint(0, 2, (n,), device=device, generator=generator) * 2 - 1 + R = R * signs.to(dtype=dtype).reshape(n, 1, 1) + return list(R.unbind(0)) + + +def apply_transformations( + systems: list[System], + targets: dict[str, TensorMap], + transformations: list[torch.Tensor], + extra_data: dict[str, TensorMap] | None = None, +) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]: + """Apply a batch of O(3) transformations to systems and TensorMaps simultaneously. + + Wigner-D matrices are derived automatically from ``transformations``; the tensor + type (scalar, Cartesian, spherical) is inferred from each TensorMap's component axis + names. Keys in ``extra_data`` that end in ``"_mask"`` pass through unchanged. + + :param systems: input systems, one per transformation + :param targets: model output TensorMaps to back-rotate (e.g. predicted energies, + forces, or spherical features) + :param transformations: per-system (3, 3) orthogonal matrices; use + :func:`random_rotations` to generate these + :param extra_data: additional TensorMaps to transform alongside ``targets`` + :return: ``(new_systems, new_targets, new_extra_data)`` + :raises ValueError: if ``len(systems) != len(transformations)``, any matrix is not a + (3, 3) orthogonal matrix, or the transformations, systems and TensorMaps do not + share a common dtype and device + """ + if len(systems) != len(transformations): + raise ValueError( + f"Expected one transformation per system, got {len(transformations)} " + f"transformations for {len(systems)} systems." + ) + for i, R in enumerate(transformations): + if R.shape != (3, 3): + raise ValueError( + f"Transformation {i} has shape {tuple(R.shape)}; expected (3, 3)." + ) + identity = torch.eye(3, device=R.device, dtype=R.dtype) + if not torch.allclose(R @ R.T, identity, atol=1e-5): + raise ValueError( + f"Transformation {i} is not orthogonal (R @ R.T deviates from I)." + ) + + if len(transformations) > 0: + # Everything is contracted with the transformations (or the Wigner-D matrices + # derived from them), so dtype and device must match throughout, otherwise the + # matmuls below fail with a much less helpful message. + reference = transformations[0] + for i, R in enumerate(transformations): + if R.dtype != reference.dtype or R.device != reference.device: + raise ValueError( + f"Transformation {i} has dtype/device ({R.dtype}, {R.device}) " + f"differing from transformation 0 ({reference.dtype}, " + f"{reference.device}); all transformations must agree." + ) + for i, system in enumerate(systems): + if ( + system.positions.dtype != reference.dtype + or system.positions.device != reference.device + ): + raise ValueError( + f"System {i} has positions with dtype/device " + f"({system.positions.dtype}, {system.positions.device}) differing " + f"from the transformations ({reference.dtype}, " + f"{reference.device})." + ) + for label, tmap in list(targets.items()) + ( + [(k, v) for k, v in extra_data.items() if not k.endswith("_mask")] + if extra_data is not None + else [] + ): + for block in tmap.blocks(): + if ( + block.values.dtype != reference.dtype + or block.values.device != reference.device + ): + raise ValueError( + f"TensorMap '{label}' has values with dtype/device " + f"({block.values.dtype}, {block.values.device}) differing " + f"from the transformations ({reference.dtype}, " + f"{reference.device})." + ) + + # Determine the highest angular momentum present across all TensorMaps + ell_max = 0 + all_tmaps = list(targets.values()) + if extra_data is not None: + all_tmaps += [v for k, v in extra_data.items() if not k.endswith("_mask")] + for tmap in all_tmaps: + for name in tmap.keys.names: + if name in _SPHERICAL_KEY_NAMES: + col = tmap.keys.column(name) + if len(col) > 0: + ell_max = max(ell_max, int(col.max())) + + if len(transformations) > 0: + device = transformations[0].device + dtype = transformations[0].dtype + angles = _rotations_to_zyz(transformations) + wigner_batch = compute_wigner_batch(ell_max, angles, device=device, dtype=dtype) + # Unbind the batch dim: {ell: (N, 2l+1, 2l+1)} to {ell: [D_0, ..., D_{N-1}]} + wigner_D_matrices: dict[int, list[torch.Tensor]] = { + ell: list(D.unbind(0)) for ell, D in wigner_batch.items() + } + else: + wigner_D_matrices = {} + + return _apply_transformations( + systems, targets, transformations, wigner_D_matrices, extra_data + ) diff --git a/python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py b/python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py new file mode 100644 index 00000000..213d0644 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_augmentation/_wigner.py @@ -0,0 +1,527 @@ +"""Private Wigner-d/Wigner-D helpers for symmetry operations. + +Adapted from the `spherical` project (MIT license), primarily from +`spherical/recursions/wignerH.py`, `spherical/utilities/indexing.py`, and the Wigner-D +assembly logic in `spherical/wigner.py`. + +This reduced metatomic copy keeps only the recurrence-based pieces needed to build real +Wigner-D matrices from ZYZ Euler angles. It intentionally does not depend on +`spinsfast`, `quaternionic`, or the public `spherical` package. +""" + +import functools + +import numpy as np +import torch + +from .._jit_compat import jit + + +@jit +def _epsilon(m: int) -> int: + if m <= 0: + return 1 + if m % 2: + return -1 + return 1 + + +@jit +def _nm_index(n: int, m: int) -> int: + return m + n * (n + 1) + + +@jit +def _nabsm_index(n: int, absm: int) -> int: + return absm + (n * (n + 1)) // 2 + + +@jit +def _wigner_h_size(mp_max: int, ell_max: int) -> int: + if ell_max < 0: + return 0 + if mp_max >= ell_max: + return (ell_max + 1) * (ell_max + 2) * (2 * ell_max + 3) // 6 + + return ( + (ell_max + 1) * (ell_max + 2) * (2 * ell_max + 3) + - 2 * (ell_max - mp_max) * (ell_max - mp_max + 1) * (ell_max - mp_max + 2) + ) // 6 + + +@jit +def _wigner_d_size(ell_min: int, mp_max: int, ell_max: int) -> int: + if mp_max >= ell_max: + return ( + ell_max * (ell_max * (4 * ell_max + 12) + 11) + + ell_min * (1 - 4 * ell_min**2) + + 3 + ) // 3 + if mp_max > ell_min: + return ( + 3 * ell_max * (ell_max + 2) + + ell_min * (1 - 4 * ell_min**2) + + mp_max + * (3 * ell_max * (2 * ell_max + 4) + mp_max * (-2 * mp_max - 3) + 5) + + 3 + ) // 3 + + return (ell_max * (ell_max + 2) - ell_min**2) * (1 + 2 * mp_max) + 2 * mp_max + 1 + + +@jit +def _wigner_h_index_base(ell: int, mp: int, m: int, mp_max: int) -> int: + local_mp_max = mp_max + if local_mp_max > ell: + local_mp_max = ell + idx = _wigner_h_size(local_mp_max, ell - 1) + if mp < 1: + idx += (local_mp_max + mp) * (2 * ell - local_mp_max + mp + 1) // 2 + else: + idx += (local_mp_max + 1) * (2 * ell - local_mp_max + 2) // 2 + idx += (mp - 1) * (2 * ell - mp + 2) // 2 + idx += m - abs(mp) + return idx + + +@jit +def _wigner_h_index(ell: int, mp: int, m: int, mp_max: int) -> int: + if ell == 0: + return 0 + + local_mp_max = mp_max + if local_mp_max > ell: + local_mp_max = ell + + if m < -mp: + if m < mp: + return _wigner_h_index_base(ell, -mp, -m, local_mp_max) + return _wigner_h_index_base(ell, -m, -mp, local_mp_max) + + if m < mp: + return _wigner_h_index_base(ell, m, mp, local_mp_max) + return _wigner_h_index_base(ell, mp, m, local_mp_max) + + +@jit +def _wigner_d_index(ell: int, mp: int, m: int, ell_min: int, mp_max: int) -> int: + idx = 0 + for ell_prev in range(ell_min, ell): + local_mp_max = mp_max if mp_max < ell_prev else ell_prev + idx += (2 * local_mp_max + 1) * (2 * ell_prev + 1) + + local_mp_max = mp_max if mp_max < ell else ell + idx += (mp + local_mp_max) * (2 * ell + 1) + idx += m + ell + return idx + + +@jit +def _step_1(hwedge): + """Seed the recurrence: D^0_{0,0} = 1.""" + hwedge[0] = 1.0 + + +@jit +def _step_2(g, h, n_max, mp_max, hwedge, hextra, hv, expi_beta): + """Fill the mp=0 column (top diagonal) of H-wedge using the g/h recurrence.""" + cos_beta = expi_beta.real + sin_beta = expi_beta.imag + sqrt3 = np.sqrt(3.0) + inverse_sqrt2 = 1.0 / np.sqrt(2.0) + if n_max > 0: + n0n_index = _wigner_h_index(1, 0, 1, mp_max) + nn_index = _nm_index(1, 1) + hwedge[n0n_index] = sqrt3 + hwedge[n0n_index - 1] = (g[nn_index - 1] * cos_beta) * inverse_sqrt2 + for n in range(2, n_max + 2): + if n <= n_max: + n0n_index = _wigner_h_index(n, 0, n, mp_max) + out = hwedge + else: + n0n_index = n + out = hextra + prev_index = _wigner_h_index(n - 1, 0, n - 1, mp_max) + nn_index = _nm_index(n, n) + const = np.sqrt(1.0 + 0.5 / n) + g_i = g[nn_index - 1] + out[n0n_index] = const * hwedge[prev_index] + out[n0n_index - 1] = g_i * cos_beta * out[n0n_index] + for i in range(2, n): + g_i = g[nn_index - i] + h_i = h[nn_index - i] + out[n0n_index - i] = ( + g_i * cos_beta * out[n0n_index - i + 1] + - h_i * sin_beta**2 * out[n0n_index - i + 2] + ) + const = 1.0 / np.sqrt(4 * n + 2) + g_i = g[nn_index - n] + h_i = h[nn_index - n] + out[n0n_index - n] = ( + g_i * cos_beta * out[n0n_index - n + 1] + - h_i * sin_beta**2 * out[n0n_index - n + 2] + ) * const + prefactor = const + for i in range(1, n): + prefactor *= sin_beta + out[n0n_index - n + i] *= prefactor + if n <= n_max: + hv[_nm_index(n, 1)] = hwedge[_wigner_h_index(n, 0, 1, mp_max)] + hv[_nm_index(n, 0)] = hwedge[_wigner_h_index(n, 0, 1, mp_max)] + prefactor = 1.0 + for n in range(1, n_max + 1): + prefactor *= sin_beta + hwedge[_wigner_h_index(n, 0, n, mp_max)] *= prefactor / np.sqrt(4 * n + 2) + prefactor *= sin_beta + hextra[n_max + 1] *= prefactor / np.sqrt(4 * (n_max + 1) + 2) + hv[_nm_index(1, 1)] = hwedge[_wigner_h_index(1, 0, 1, mp_max)] + hv[_nm_index(1, 0)] = hwedge[_wigner_h_index(1, 0, 1, mp_max)] + + +@jit +def _step_3(a, b, n_max, mp_max, hwedge, hextra, expi_beta): + """Fill the mp=1 sub-diagonal of H-wedge using the a/b recurrence.""" + cos_beta = expi_beta.real + sin_beta = expi_beta.imag + if n_max > 0 and mp_max > 0: + for n in range(1, n_max + 1): + i1 = _wigner_h_index(n, 1, 1, mp_max) + if n + 1 <= n_max: + i2 = _wigner_h_index(n + 1, 0, 0, mp_max) + h2 = hwedge + else: + i2 = 0 + h2 = hextra + i3 = _nm_index(n + 1, 0) + i4 = _nabsm_index(n, 1) + inverse_b5 = 1.0 / b[i3] + for i in range(n): + b6 = b[-i + i3 - 2] + b7 = b[i + i3] + a8 = a[i + i4] + hwedge[i + i1] = inverse_b5 * ( + 0.5 + * ( + b6 * (1 - cos_beta) * h2[i + i2 + 2] + - b7 * (1 + cos_beta) * h2[i + i2] + ) + - a8 * sin_beta * h2[i + i2 + 1] + ) + + +@jit +def _step_4(d, n_max, mp_max, hwedge, hv): + """Fill H-wedge for mp >= 2 by stepping up in mp via the d recurrence.""" + if n_max > 0 and mp_max > 0: + for n in range(2, n_max + 1): + for mp in range(1, min(n, mp_max)): + i1 = _wigner_h_index(n, mp + 1, mp + 1, mp_max) - 1 + i2 = _wigner_h_index(n, mp - 1, mp, mp_max) + i3 = _wigner_h_index(n, mp, mp, mp_max) - 1 + i4 = _wigner_h_index(n, mp, mp + 1, mp_max) + i5 = _nm_index(n, mp) + i6 = _nm_index(n, mp - 1) + inverse_d5 = 1.0 / d[i5] + d6 = d[i6] + hv[_nm_index(n, mp + 1)] = inverse_d5 * ( + d6 * hwedge[i2] - d[i6] * hv[_nm_index(n, mp)] + d[i5] * hwedge[i4] + ) + for i in range(1, n - mp): + d7 = d[i + i6] + d8 = d[i + i5] + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] - d7 * hwedge[i + i3] + d8 * hwedge[i + i4] + ) + i = n - mp + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] - d[i + i6] * hwedge[i + i3] + ) + + +@jit +def _step_5(d, n_max, mp_max, hwedge, hv): + """Fill H-wedge for mp <= 0 by stepping down in mp via the d recurrence.""" + if n_max > 0 and mp_max > 0: + for n in range(0, n_max + 1): + for mp in range(0, -min(n, mp_max), -1): + i1 = _wigner_h_index(n, mp - 1, -mp + 1, mp_max) - 1 + i2 = _wigner_h_index(n, mp + 1, -mp + 1, mp_max) - 1 + i3 = _wigner_h_index(n, mp, -mp, mp_max) - 1 + i4 = _wigner_h_index(n, mp, -mp + 1, mp_max) + i5 = _nm_index(n, mp - 1) + i6 = _nm_index(n, mp) + i7 = _nm_index(n, -mp - 1) + i8 = _nm_index(n, -mp) + inverse_d5 = 1.0 / d[i5] + d6 = d[i6] + d7 = d[i7] + d8 = d[i8] + if mp == 0: + hv[_nm_index(n, mp - 1)] = inverse_d5 * ( + d6 * hv[_nm_index(n, mp + 1)] + + d7 * hv[_nm_index(n, mp)] + - d8 * hwedge[i4] + ) + else: + hv[_nm_index(n, mp - 1)] = inverse_d5 * ( + d6 * hwedge[i2] + d7 * hv[_nm_index(n, mp)] - d8 * hwedge[i4] + ) + for i in range(1, n + mp): + d7 = d[i + i7] + d8 = d[i + i8] + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] + d7 * hwedge[i + i3] - d8 * hwedge[i + i4] + ) + i = n + mp + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] + d[i + i7] * hwedge[i + i3] + ) + + +def _create_wigner_coefficients(ell_max: int): + """Pre-compute the scalar recurrence coefficients used by the five-step Risbo + recurrence. + + Returns five arrays ``(a, b, d, g, h)`` indexed by ``(n, m)`` pairs flattened in + lexicographic order up to ``ell_max + 1`` (one step beyond the target to seed the + recurrence). The arrays contain the coupling constants that appear in the three-term + recurrences for the Wigner d-matrix entries. + + :param ell_max: highest angular-momentum order needed + :return: ``(a, b, d, g, h)`` numpy float arrays + """ + n = np.array([n for n in range(ell_max + 2) for _ in range(-n, n + 1)]) + m = np.array([m for n in range(ell_max + 2) for m in range(-n, n + 1)]) + absn = np.array([n for n in range(ell_max + 2) for _ in range(n + 1)]) + absm = np.array([m for n in range(ell_max + 2) for m in range(n + 1)]) + + a = np.sqrt( + (absn + 1 + absm) * (absn + 1 - absm) / ((2 * absn + 1) * (2 * absn + 3)) + ) + b = np.sqrt((n - m - 1) * (n - m) / ((2 * n - 1) * (2 * n + 1))) + b[m < 0] *= -1 + d = 0.5 * np.sqrt((n - m) * (n + m + 1)) + d[m < 0] *= -1 + with np.errstate(divide="ignore", invalid="ignore"): + g = 2 * (m + 1) / np.sqrt((n - m) * (n + m + 1)) + h = np.sqrt((n + m + 2) * (n - m - 1) / ((n - m) * (n + m + 1))) + return a, b, d, g, h + + +def _complex_powers(z: complex, ell_max: int) -> np.ndarray: + powers = np.empty(ell_max + 1, dtype=np.complex128) + powers[0] = 1.0 + 0.0j + for idx in range(1, ell_max + 1): + powers[idx] = powers[idx - 1] * z + return powers + + +def _to_euler_phases( + alpha: float, beta: float, gamma: float +) -> tuple[complex, complex, complex]: + # Match spherical.Wigner's convention after converting scipy's ZYZ Euler angles + # into the phases used by the recurrence. + z_alpha = np.exp(-1j * alpha) + expi_beta = np.exp(1j * beta) + z_gamma = np.exp(-1j * gamma) + return z_alpha, expi_beta, z_gamma + + +def _compute_wigner_d_complex( + ell_max: int, alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> np.ndarray: + """Compute complex Wigner-D matrix elements for all ell in ``[0, ell_max]``. + + Implements the five-step Risbo/Trapani-Navaza recurrence (see module docstring). + Results are stored in a flat array indexed by :func:`_wigner_d_index`. + + :param ell_max: maximum angular-momentum order + :param alpha: ZYZ first rotation angle, arbitrary shape + :param beta: ZYZ second rotation angle, same shape as ``alpha`` + :param gamma: ZYZ third rotation angle, same shape as ``alpha`` + :return: complex array of shape ``(*alpha.shape, dsize)`` + """ + if not (alpha.shape == beta.shape == gamma.shape): + raise ValueError("alpha, beta, and gamma must have identical shapes") + + mp_max = ell_max + a, b, d, g, h = _create_wigner_coefficients(ell_max) + hsize = _wigner_h_size(mp_max, ell_max) + dsize = _wigner_d_size(0, mp_max, ell_max) + result = np.zeros(alpha.shape + (dsize,), dtype=np.complex128) + + for index in np.ndindex(alpha.shape): + z_alpha, expi_beta, z_gamma = _to_euler_phases( + float(alpha[index]), float(beta[index]), float(gamma[index]) + ) + hwedge = np.zeros(hsize, dtype=np.float64) + hv = np.zeros((ell_max + 1) ** 2, dtype=np.float64) + hextra = np.zeros(ell_max + 2, dtype=np.float64) + + _step_1(hwedge) + _step_2(g, h, ell_max, mp_max, hwedge, hextra, hv, expi_beta) + _step_3(a, b, ell_max, mp_max, hwedge, hextra, expi_beta) + _step_4(d, ell_max, mp_max, hwedge, hv) + _step_5(d, ell_max, mp_max, hwedge, hv) + + z_alpha_powers = _complex_powers(z_alpha, ell_max) + z_gamma_powers = _complex_powers(z_gamma, ell_max) + out = result[index] + for ell in range(0, ell_max + 1): + for mp in range(-ell, 0): + i_d = _wigner_d_index(ell, mp, -ell, 0, mp_max) + for m in range(-ell, 0): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[-m].conjugate() + * z_alpha_powers[-mp].conjugate() + ) + i_d += 1 + for m in range(0, ell + 1): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[m] + * z_alpha_powers[-mp].conjugate() + ) + i_d += 1 + for mp in range(0, ell + 1): + i_d = _wigner_d_index(ell, mp, -ell, 0, mp_max) + for m in range(-ell, 0): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[-m].conjugate() + * z_alpha_powers[mp] + ) + i_d += 1 + for m in range(0, ell + 1): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[m] + * z_alpha_powers[mp] + ) + i_d += 1 + + return result + + +def compute_complex_wigner_d_matrices( + ell_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], +) -> dict[int, np.ndarray]: + """Return complex Wigner-D matrices for ell in ``[0, ell_max]`` at the given ZYZ + angles. + + :param ell_max: maximum angular-momentum order + :param angles: ``(alpha, beta, gamma)`` ZYZ Euler-angle arrays of matching shape + :return: ``{ell: array of shape (*angles[0].shape, 2*ell+1, 2*ell+1)}`` + """ + alpha, beta, gamma = angles + raw = _compute_wigner_d_complex(ell_max, alpha, beta, gamma) + matrices: dict[int, np.ndarray] = {} + for ell in range(ell_max + 1): + shape = alpha.shape + (2 * ell + 1, 2 * ell + 1) + block = np.zeros(shape, dtype=np.complex128) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + block[..., mp + ell, m + ell] = raw[ + ..., _wigner_d_index(ell, mp, m, 0, ell_max) + ] + matrices[ell] = block + return matrices + + +def compute_real_wigner_d_matrices( + ell_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], + complex_to_real: dict[int, np.ndarray], +) -> dict[int, torch.Tensor]: + """Convert complex Wigner-D matrices to real ones using the provided + change-of-basis. + + :param ell_max: maximum angular-momentum order + :param angles: ``(alpha, beta, gamma)`` ZYZ Euler-angle arrays + :param complex_to_real: ``{ell: (2*ell+1, 2*ell+1)}`` unitary transform from complex + to real spherical harmonics + :return: ``{ell: real tensor of shape (*angles[0].shape, 2*ell+1, 2*ell+1)}`` + :raises ValueError: if imaginary residuals exceed numerical tolerance + """ + complex_matrices = compute_complex_wigner_d_matrices(ell_max, angles) + real_matrices: dict[int, torch.Tensor] = {} + for ell, matrix in complex_matrices.items(): + transform = complex_to_real[ell] + matrix = np.einsum("ij,...jk,kl->...il", transform.conj(), matrix, transform.T) + # Recursion accumulates floating-point noise that grows with ell, so scale + # the tolerance by the magnitude of the matrix instead of using a fixed atol. + scale = float(np.max(np.abs(matrix.real))) if matrix.size else 1.0 + atol = max(1e-9, scale * 1e-10) + if not np.allclose(matrix.imag, 0.0, atol=atol): + raise ValueError("real Wigner matrix conversion produced complex values") + real_matrices[ell] = torch.from_numpy(matrix.real) + return real_matrices + + +@functools.lru_cache(maxsize=None) +def _complex_to_real_spherical_harmonics_transform(ell: int) -> np.ndarray: + """ + Generate the transformation matrix from complex spherical harmonics to real + spherical harmonics for a given l. Returns a transformation matrix of shape ``(2l+1, + 2l+1)``. + """ + if ell < 0 or not isinstance(ell, int): + raise ValueError("l must be a non-negative integer.") + + size = 2 * ell + 1 + T = np.zeros((size, size), dtype=complex) + + for m in range(-ell, ell + 1): + m_index = m + ell + if m > 0: + T[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m + T[m_index, ell - m] = 1 / np.sqrt(2) + elif m < 0: + T[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m + T[m_index, ell - abs(m)] = 1j / np.sqrt(2) + else: + T[m_index, ell] = 1 + + return T + + +def compute_real_wigner_matrices( + o3_lambda_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], +) -> dict[int, torch.Tensor]: + """Build the real Wigner-D matrices for ``ell = 0..o3_lambda_max`` at the given + ZYZ Euler angles, using the cached complex-to-real transform per ell.""" + complex_to_real = { + ell: _complex_to_real_spherical_harmonics_transform(ell) + for ell in range(o3_lambda_max + 1) + } + return compute_real_wigner_d_matrices(o3_lambda_max, angles, complex_to_real) + + +def compute_wigner_batch( + ell_max: int, + angles: tuple[np.ndarray, np.ndarray, np.ndarray], + *, + device: torch.device, + dtype: torch.dtype, +) -> dict[int, torch.Tensor]: + """Real Wigner-D matrices for ``ell = 0..ell_max`` at the given angles, cast to + the requested device and dtype.""" + return { + ell: tensor.to(device=device, dtype=dtype) + for ell, tensor in compute_real_wigner_matrices(ell_max, angles).items() + } diff --git a/python/metatomic_torch/metatomic/torch/_jit_compat.py b/python/metatomic_torch/metatomic/torch/_jit_compat.py new file mode 100644 index 00000000..d6466203 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_jit_compat.py @@ -0,0 +1,13 @@ +import functools + + +def _identity_decorator(func): + return func + + +try: + import numba as _numba +except ImportError: # pragma: no cover + jit = _identity_decorator +else: + jit = functools.partial(_numba.njit, cache=True) diff --git a/python/metatomic_torch/tests/augmentation.py b/python/metatomic_torch/tests/augmentation.py new file mode 100644 index 00000000..9bcaba8e --- /dev/null +++ b/python/metatomic_torch/tests/augmentation.py @@ -0,0 +1,739 @@ +import numpy as np +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import ( + NeighborListOptions, + System, + apply_transformations, + random_rotations, + register_autograd_neighbors, +) +from metatomic.torch._augmentation import _apply_transformations, _rotations_to_zyz +from metatomic.torch._augmentation._wigner import compute_wigner_batch + + +def _axis_angle(axis, theta): + """A general (non-degenerate, beta != 0) rotation matrix from axis and angle.""" + axis = np.asarray(axis, dtype=float) + axis = axis / np.linalg.norm(axis) + x, y, z = axis + c, s = np.cos(theta), np.sin(theta) + one_minus_c = 1.0 - c + return np.array( + [ + [ + c + x * x * one_minus_c, + x * y * one_minus_c - z * s, + x * z * one_minus_c + y * s, + ], + [ + y * x * one_minus_c + z * s, + c + y * y * one_minus_c, + y * z * one_minus_c - x * s, + ], + [ + z * x * one_minus_c - y * s, + z * y * one_minus_c + x * s, + c + z * z * one_minus_c, + ], + ] + ) + + +# Change of basis from Cartesian (x, y, z) to real ell=1 spherical harmonics, ordered +# (m=-1, 0, +1) = (y, z, x). The real ell=1 Wigner-D matrix satisfies D1 = C @ R @ C.T, +# which lets us cross-check the spherical path (Wigner-D) against the trivially-correct +# Cartesian path under arbitrary rotations. +_CART_TO_SPHERICAL_L1 = torch.tensor( + [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]], dtype=torch.float64 +) + + +def _make_system(types, positions=None, cell=None, pbc=None): + n_atoms = len(types) + if positions is None: + positions = torch.zeros((n_atoms, 3), dtype=torch.float64) + if cell is None: + cell = torch.zeros((3, 3), dtype=torch.float64) + if pbc is None: + pbc = torch.tensor([False, False, False]) + return System( + types=torch.tensor(types, dtype=torch.int32), + positions=positions, + cell=cell, + pbc=pbc, + ) + + +def _rotation_batch(alphas): + transformations = [] + for alpha in alphas: + transformations.append( + torch.tensor( + [ + [np.cos(alpha), -np.sin(alpha), 0.0], + [np.sin(alpha), np.cos(alpha), 0.0], + [0.0, 0.0, 1.0], + ], + dtype=torch.float64, + ) + ) + + zeros = np.zeros(len(alphas)) + wigner_D_matrices = { + ell: list(matrix.unbind(0)) + for ell, matrix in compute_wigner_batch( + 1, + (np.asarray(alphas), zeros, zeros), + device=torch.device("cpu"), + dtype=torch.float64, + ).items() + } + return transformations, wigner_D_matrices + + +def _row_indices(samples, n_systems): + system_ids = samples.column("system").to(dtype=torch.long) + return [ + torch.nonzero(system_ids == system_index, as_tuple=False).reshape(-1) + for system_index in range(n_systems) + ] + + +def test_sparse_atomic_basis_rank2_augmentation_with_missing_system_rows(): + """Rank-2 spherical features rotate on each mu axis, and empty system row-groups + are a no-op. + + The single block carries two component axes (``o3_mu_1``, ``o3_mu_2``) and all of + its rows belong to system 0; system 1 contributes no rows (the "missing system + rows"). The test confirms only system 0's rows are transformed -- by + ``D1 @ v @ D1.T`` on the two mu axes -- that the empty row-group for system 1 leaves + the block untouched, and that the values actually change (guarding against an + accidental identity transform). + """ + systems = [_make_system([1, 1]), _make_system([8, 8])] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 2, np.pi]) + + components = [ + Labels( + ["o3_mu_1"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + ["o3_mu_2"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ), + ] + property_labels = Labels( + ["n_1", "n_2"], + torch.tensor([[0, 0]], dtype=torch.int32), + ) + values = torch.arange(18, dtype=torch.float64).reshape(2, 3, 3, 1) + tensor = TensorMap( + Labels( + ["o3_lambda_1", "o3_lambda_2", "o3_sigma_1", "o3_sigma_2", "atom_type"], + torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32), + ), + [ + TensorBlock( + values=values, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=components, + properties=property_labels, + ) + ], + ) + + _, augmented_targets, _ = _apply_transformations( + systems, + {"target": tensor}, + transformations, + wigner_D_matrices, + ) + augmented = augmented_targets["target"] + + expected_values = values.clone() + rows = _row_indices(tensor.block().samples, len(systems))[0] + expected_values[rows] = torch.einsum( + "Aa,iabp,bB->iABp", + wigner_D_matrices[1][0], + values[rows], + wigner_D_matrices[1][0].T, + ) + expected = TensorMap( + tensor.keys, + [ + TensorBlock( + values=expected_values, + samples=tensor.block().samples, + components=tensor.block().components, + properties=tensor.block().properties, + ) + ], + ) + + assert augmented.block().samples == expected.block().samples + assert torch.allclose(augmented.block().values, expected.block().values, atol=1e-12) + assert not torch.allclose(augmented.block().values, values) + + +def test_system_positions_and_cell_are_rotated(): + # Non-trivial positions and cell so the rotation is observable; verifies that + # `_apply_transformations` does not silently leave the System unchanged. + positions_a = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype=torch.float64, + ) + positions_b = torch.tensor( + [[0.5, 0.5, 0.0], [1.0, -1.0, 0.5]], + dtype=torch.float64, + ) + cell_a = torch.eye(3, dtype=torch.float64) * 3.0 + cell_b = torch.eye(3, dtype=torch.float64) * 4.0 + pbc = torch.tensor([True, True, True]) + systems = [ + _make_system([1, 1, 1], positions=positions_a, cell=cell_a, pbc=pbc), + _make_system([8, 8], positions=positions_b, cell=cell_b, pbc=pbc), + ] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 3, np.pi / 4]) + + new_systems, _, _ = _apply_transformations( + systems, + {}, + transformations, + wigner_D_matrices, + ) + + assert len(new_systems) == 2 + for original, new, R in zip(systems, new_systems, transformations, strict=True): + assert torch.allclose(new.positions, original.positions @ R.T, atol=1e-12) + assert torch.allclose(new.cell, original.cell @ R.T, atol=1e-12) + # types and pbc must pass through unchanged + assert torch.equal(new.types, original.types) + assert torch.equal(new.pbc, original.pbc) + + +def test_random_rotations_are_orthogonal(): + rotations = random_rotations(20, device=torch.device("cpu"), dtype=torch.float64) + assert len(rotations) == 20 + identity = torch.eye(3, dtype=torch.float64) + for R in rotations: + assert R.shape == (3, 3) + assert torch.allclose(R @ R.T, identity, atol=1e-10) + assert abs(float(torch.det(R)) - 1.0) < 1e-10 + + +def test_random_rotations_include_inversions(): + # With n=100 the probability that all determinants have the same sign is 2^{-99}. + rotations = random_rotations( + 100, device=torch.device("cpu"), dtype=torch.float64, include_inversions=True + ) + dets = torch.tensor([float(torch.det(R)) for R in rotations], dtype=torch.float64) + assert torch.allclose(dets.abs(), torch.ones(100, dtype=torch.float64), atol=1e-10) + assert (dets > 0).any() and (dets < 0).any() + + +def test_apply_transformations_raises_on_length_mismatch(): + systems = [_make_system([1])] + two_rotations = [torch.eye(3, dtype=torch.float64)] * 2 + with pytest.raises(ValueError, match="one transformation per system"): + apply_transformations(systems, {}, two_rotations) + + +def test_apply_transformations_raises_on_non_orthogonal(): + systems = [_make_system([1])] + not_orthogonal = torch.tensor( + [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], dtype=torch.float64 + ) + with pytest.raises(ValueError, match="not orthogonal"): + apply_transformations(systems, {}, [not_orthogonal]) + + +# Rotations exercising the full ZYZ decomposition + real Wigner-D path, including both +# gimbal-lock branches of `_rotations_to_zyz`: a generic beta != 0 rotation, a z-axis +# rotation (beta == 0), a 180-degree rotation about x (beta == pi), and an improper one. +_GENERAL_ROTATIONS = [ + torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64), + torch.tensor(_axis_angle([-2.0, 1.0, 0.5], 2.4), dtype=torch.float64), + torch.tensor(_axis_angle([0.0, 0.0, 1.0], 0.9), dtype=torch.float64), # beta=0 + torch.tensor(_axis_angle([1.0, 0.0, 0.0], np.pi), dtype=torch.float64), # beta=pi + # improper rotation (det = -1): a proper rotation composed with inversion + -torch.tensor(_axis_angle([0.3, -1.0, 2.0], 1.1), dtype=torch.float64), +] + + +@pytest.mark.parametrize("R", _GENERAL_ROTATIONS) +def test_general_rotation_ell1_wigner_matches_cartesian(R): + # For ell=1 the real Wigner-D matrix must equal C @ R_proper @ C.T, where R_proper + # is the proper part of R. This validates the Euler decomposition and complex->real + # conversion for general (non-degenerate) rotations, independently of the rest of + # the augmentation machinery. + proper_R = R if torch.det(R) > 0 else -R + angles = _rotations_to_zyz([R]) + D = compute_wigner_batch(1, angles, device=torch.device("cpu"), dtype=torch.float64) + expected = _CART_TO_SPHERICAL_L1 @ proper_R @ _CART_TO_SPHERICAL_L1.T + assert torch.allclose(D[1][0], expected, atol=1e-12) + + +@pytest.mark.parametrize("R", _GENERAL_ROTATIONS) +def test_general_rotation_spherical_matches_cartesian_vector(R): + # End-to-end through the public API: a Cartesian vector target and a spherical ell=1 + # (sigma=1, i.e. a true vector) target encoding the same vectors via C must stay + # related by C after augmentation. Covers general rotations *and* the inversion + # parity factor for the improper case. + systems = [_make_system([1, 8])] + cartesian_vectors = torch.randn(2, 3, 1, dtype=torch.float64) + + cartesian = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=cartesian_vectors, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + # spherical encoding: w = C @ v along the component axis + spherical_values = torch.einsum( + "Aa,iap->iAp", _CART_TO_SPHERICAL_L1, cartesian_vectors + ) + spherical = TensorMap( + Labels(["o3_lambda", "o3_sigma"], torch.tensor([[1, 1]], dtype=torch.int32)), + [ + TensorBlock( + values=spherical_values, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[ + Labels( + ["o3_mu"], torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1) + ) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + + _, out, _ = apply_transformations( + systems, {"cart": cartesian, "spher": spherical}, [R] + ) + + rotated_cart = out["cart"].block().values + rotated_spher = out["spher"].block().values + expected_spher = torch.einsum("Aa,iap->iAp", _CART_TO_SPHERICAL_L1, rotated_cart) + assert torch.allclose(rotated_spher, expected_spher, atol=1e-12) + + +def test_scalar_energy_gradients_are_rotated(): + # A scalar (energy-like) block carrying positions and strain gradients across two + # systems. Positions gradients transform as vectors (R @ g); strain gradients as + # rank-2 Cartesian tensors (R @ S @ R.T). Verifies gradient support and per-system + # routing through the parent block's "system" column. + systems = [ + _make_system([1, 1]), + _make_system([8, 8, 8]), + ] + R0 = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + R1 = torch.tensor(_axis_angle([0.0, 1.0, 1.0], 1.9), dtype=torch.float64) + + values = torch.tensor([[1.0], [2.0]], dtype=torch.float64) + pos_grad = torch.randn(5, 3, 1, dtype=torch.float64) # 2 + 3 atoms + strain_grad = torch.randn(2, 3, 3, 1, dtype=torch.float64) + + block = TensorBlock( + values=values, + samples=Labels(["system"], torch.tensor([[0], [1]], dtype=torch.int32)), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ) + block.add_gradient( + "positions", + TensorBlock( + values=pos_grad, + samples=Labels( + ["sample", "atom"], + torch.tensor( + [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], dtype=torch.int32 + ), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ), + ) + block.add_gradient( + "strain", + TensorBlock( + values=strain_grad, + samples=Labels(["sample"], torch.tensor([[0], [1]], dtype=torch.int32)), + components=[ + Labels(["xyz_1"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)), + Labels(["xyz_2"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)), + ], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ), + ) + tensor = TensorMap(Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), [block]) + + _, out, _ = apply_transformations(systems, {"energy": tensor}, [R0, R1]) + out_block = out["energy"].block() + + # scalar values unchanged + assert torch.allclose(out_block.values, values) + + # positions gradients: rows 0,1 -> R0, rows 2,3,4 -> R1 + expected_pos = pos_grad.clone() + expected_pos[:2] = torch.einsum("Aa,iap->iAp", R0, pos_grad[:2]) + expected_pos[2:] = torch.einsum("Aa,iap->iAp", R1, pos_grad[2:]) + assert torch.allclose(out_block.gradient("positions").values, expected_pos) + + # strain gradients: row 0 -> R0 S R0.T, row 1 -> R1 S R1.T + expected_strain = strain_grad.clone() + expected_strain[0] = torch.einsum("Aa,abp,Bb->ABp", R0, strain_grad[0], R0) + expected_strain[1] = torch.einsum("Aa,abp,Bb->ABp", R1, strain_grad[1], R1) + assert torch.allclose(out_block.gradient("strain").values, expected_strain) + + +def test_positions_gradient_routes_by_parent_system_label(): + """Positions gradients follow the parent block's "system" label, not row position. + + The parent energy rows are given in non-sorted label order ([1, 0]); each gradient + row must be rotated by the transformation of the system named in *its parent row's* + "system" label, exactly as the values are routed. A positional routing (pairing + gradient ``sample == i`` with ``transformations[i]``) would mis-pair the gradient + rows with the systems whenever the parent rows are not sorted by label. + """ + # systems are passed sorted by label: systems[0] <-> label 0, systems[1] <-> label 1 + systems = [_make_system([1, 1]), _make_system([8, 8, 8])] + R0 = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + R1 = torch.tensor(_axis_angle([0.0, 1.0, 1.0], 1.9), dtype=torch.float64) + + # parent rows in non-sorted order: row 0 -> system label 1, row 1 -> system label 0 + values = torch.tensor([[1.0], [2.0]], dtype=torch.float64) + block = TensorBlock( + values=values, + samples=Labels(["system"], torch.tensor([[1], [0]], dtype=torch.int32)), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ) + # sample 0 -> parent row 0 (label 1, 3 atoms); sample 1 -> parent row 1 (label 0, 2) + pos_grad = torch.randn(5, 3, 1, dtype=torch.float64) + block.add_gradient( + "positions", + TensorBlock( + values=pos_grad, + samples=Labels( + ["sample", "atom"], + torch.tensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], dtype=torch.int32 + ), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["energy"], torch.tensor([[0]], dtype=torch.int32)), + ), + ) + tensor = TensorMap(Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), [block]) + + _, out, _ = apply_transformations(systems, {"energy": tensor}, [R0, R1]) + grad = out["energy"].block().gradient("positions").values + + expected = pos_grad.clone() + # parent row 0 has label 1 -> R1 (gradient rows with sample == 0: indices 0,1,2) + expected[:3] = torch.einsum("Aa,iap->iAp", R1, pos_grad[:3]) + # parent row 1 has label 0 -> R0 (gradient rows with sample == 1: indices 3,4) + expected[3:] = torch.einsum("Aa,iap->iAp", R0, pos_grad[3:]) + assert torch.allclose(grad, expected, atol=1e-12) + + +def test_unsupported_component_axis_raises(): + # A component axis that is neither Cartesian nor spherical must raise rather than be + # silently passed through unchanged. + systems = [_make_system([1])] + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=torch.zeros(1, 3, 1, dtype=torch.float64), + samples=Labels(["system"], torch.tensor([[0]], dtype=torch.int32)), + components=[ + Labels( + ["direction"], torch.arange(3, dtype=torch.int32).reshape(-1, 1) + ) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + with pytest.raises(ValueError, match="neither a Cartesian"): + apply_transformations( + systems, {"bad": tensor}, [torch.eye(3, dtype=torch.float64)] + ) + + +def test_neighbor_list_vectors_are_rotated(): + R = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + system = _make_system( + [1, 1], + positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float64), + ) + options = NeighborListOptions(cutoff=2.0, full_list=True, strict=False) + vectors = torch.tensor( + [[[1.0], [0.0], [0.0]]], dtype=torch.float64 + ) # (1 pair, 3, 1) + neighbors = TensorBlock( + values=vectors.clone(), + samples=Labels( + [ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + torch.tensor([[0, 1, 0, 0, 0]], dtype=torch.int32), + ), + components=[Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1))], + properties=Labels(["distance"], torch.tensor([[0]], dtype=torch.int32)), + ) + register_autograd_neighbors(system, neighbors) + system.add_neighbor_list(options, neighbors) + + new_systems, _, _ = apply_transformations([system], {}, [R]) + + new_vectors = new_systems[0].get_neighbor_list(options).values + expected = (vectors.squeeze(-1) @ R.T).unsqueeze(-1) + assert torch.allclose(new_vectors, expected, atol=1e-12) + + +def test_system_custom_data_is_rotated_by_tensor_type(): + """Registered System data is rotated by the same machinery as targets. + + Scalar blocks must pass through unchanged, ``xyz`` vector blocks must rotate by + ``R``, and a multi-block TensorMap must be handled block-by-block (the previous + implementation rejected anything other than a single block). This is the path + exercised when a model carries per-atom geometric quantities (e.g. local frames) as + System data that has to follow the rotation of the structure. + """ + R = torch.tensor(_axis_angle([0.3, -0.7, 0.5], 0.9), dtype=torch.float64) + system = _make_system( + [1, 8], + positions=torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64), + ) + + # scalar per-atom data spread over two blocks: both must pass through untouched + scalar = TensorMap( + keys=Labels(["block"], torch.tensor([[0], [1]], dtype=torch.int32)), + blocks=[ + TensorBlock( + values=torch.tensor([[1.0], [2.0]], dtype=torch.float64), + samples=Labels.range("atom", 2), + components=[], + properties=Labels.range("p", 1), + ), + TensorBlock( + values=torch.tensor([[3.0], [4.0]], dtype=torch.float64), + samples=Labels.range("atom", 2), + components=[], + properties=Labels.range("p", 1), + ), + ], + ) + # xyz vector per-atom data: must rotate by R on the component axis + vector_values = torch.tensor( + [[[1.0], [0.0], [0.0]], [[0.0], [2.0], [0.0]]], dtype=torch.float64 + ) # (2 atoms, 3, 1) + vector = TensorMap( + keys=Labels.single(), + blocks=[ + TensorBlock( + values=vector_values.clone(), + samples=Labels.range("atom", 2), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels.range("p", 1), + ) + ], + ) + system.add_data("custom::scalar", scalar) + system.add_data("custom::vector", vector) + + new_systems, _, _ = apply_transformations([system], {}, [R]) + new_system = new_systems[0] + + new_scalar = new_system.get_data("custom::scalar") + for block_id in range(len(scalar.keys)): + assert torch.allclose( + new_scalar.block_by_id(block_id).values, + scalar.block_by_id(block_id).values, + ) + + new_vector = new_system.get_data("custom::vector").block().values + expected_vector = (vector_values.squeeze(-1) @ R.T).unsqueeze(-1) + assert torch.allclose(new_vector, expected_vector, atol=1e-12) + # sanity: the rotation actually changed the vector data + assert not torch.allclose(new_vector, vector_values) + + +def test_spherical_system_data_is_passed_through_unrotated(): + """Spherical System data is allowed but not rotated (no Wigner-D computed for it). + + Attaching an ``o3_mu`` block to a System via ``add_data`` is permitted, but the + augmentation does not build Wigner-D matrices for System data. Rather than crash or + rotate it incorrectly, such data must be passed through unchanged while the geometry + is still rotated -- so a model relying on it is responsible for its equivariance. + """ + R = torch.tensor(_axis_angle([0.2, 0.5, -0.8], 1.1), dtype=torch.float64) + system = _make_system( + [1, 8], + positions=torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64), + ) + spherical_values = torch.tensor( + [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float64 + ) # (2 atoms, 2*ell+1 = 3, 1) + mu = Labels(["o3_mu"], torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1)) + spherical = TensorMap( + keys=Labels( + ["o3_lambda", "o3_sigma"], torch.tensor([[1, 1]], dtype=torch.int32) + ), + blocks=[ + TensorBlock( + values=spherical_values.clone(), + samples=Labels.range("atom", 2), + components=[mu], + properties=Labels.range("p", 1), + ) + ], + ) + system.add_data("custom::spherical", spherical) + + # no targets -> no Wigner-D matrices are computed; this would KeyError if the + # spherical block were (incorrectly) sent through the rotation path + new_systems, _, _ = apply_transformations([system], {}, [R]) + + new_spherical = new_systems[0].get_data("custom::spherical").block().values + assert torch.allclose(new_spherical, spherical_values) + # the geometry itself is still rotated + assert torch.allclose(new_systems[0].positions, system.positions @ R.T, atol=1e-12) + + +def test_random_rotations_generator_is_reproducible(): + g1 = torch.Generator().manual_seed(12345) + g2 = torch.Generator().manual_seed(12345) + a = random_rotations( + 8, + device=torch.device("cpu"), + dtype=torch.float64, + include_inversions=True, + generator=g1, + ) + b = random_rotations( + 8, + device=torch.device("cpu"), + dtype=torch.float64, + include_inversions=True, + generator=g2, + ) + for Ra, Rb in zip(a, b, strict=True): + assert torch.equal(Ra, Rb) + + +def test_apply_transformations_raises_on_dtype_mismatch(): + systems = [_make_system([1, 8])] + R = torch.eye( + 3, dtype=torch.float32 + ) # mismatched vs float64 target/positions below + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=torch.zeros(2, 3, 1, dtype=torch.float64), + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + with pytest.raises(ValueError, match="dtype/device"): + apply_transformations(systems, {"t": tensor}, [R]) + + +def test_arbitrary_system_labels_are_remapped(): + # metatrain keeps each structure's original dataset index in the "system" column, + # so a 2-system batch can be labelled e.g. [92, 38] rather than [0, 1]. The i-th + # sorted label maps to system i: 38 -> system 0 (R0), 92 -> system 1 (R1). + systems = [_make_system([1, 8]), _make_system([1, 8])] + R0 = torch.tensor(_axis_angle([1.0, 2.0, 3.0], 0.7), dtype=torch.float64) + R1 = torch.tensor(_axis_angle([0.0, 1.0, 1.0], 1.9), dtype=torch.float64) + vectors = torch.randn(2, 3, 1, dtype=torch.float64) + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=vectors, + samples=Labels( + ["system", "atom"], + torch.tensor([[92, 0], [38, 0]], dtype=torch.int32), + ), + components=[ + Labels(["xyz"], torch.arange(3, dtype=torch.int32).reshape(-1, 1)) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + _, out, _ = apply_transformations(systems, {"t": tensor}, [R0, R1]) + result = out["t"].block().values + # row 0 has label 92 (sorted position 1) -> R1; row 1 has label 38 -> R0 + expected = vectors.clone() + expected[0] = torch.einsum("Aa,ap->Ap", R1, vectors[0]) + expected[1] = torch.einsum("Aa,ap->Ap", R0, vectors[1]) + assert torch.allclose(result, expected, atol=1e-12) + + +def test_too_many_distinct_system_ids_raises(): + # More distinct "system" labels than systems is genuinely ambiguous and must raise. + systems = [_make_system([1]), _make_system([8])] + R = torch.eye(3, dtype=torch.float64) + tensor = TensorMap( + Labels(["o3_lambda", "o3_sigma"], torch.tensor([[1, 1]], dtype=torch.int32)), + [ + TensorBlock( + values=torch.zeros(3, 3, 1, dtype=torch.float64), + samples=Labels( + ["system", "atom"], + torch.tensor([[10, 0], [20, 0], [30, 0]], dtype=torch.int32), + ), + components=[ + Labels( + ["o3_mu"], torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1) + ) + ], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + with pytest.raises(ValueError, match="distinct system indices"): + apply_transformations(systems, {"t": tensor}, [R, R])