
Implement augmentation routines in metatomic #264

Open
ppegolo wants to merge 3 commits into main from augmentation

@ppegolo ppegolo commented Jun 17, 2026

Contributor

Augmentation is currently implemented in metatrain, where it depends on spherical to augment spherical targets.
In preparation for the SymmetrizedModel PR (currently #119, which would depend on metatrain just for the augmentation routines), I re-implemented here the relevant parts of spherical (MIT license), dropping unnecessary code and most dependencies, together with all the functions needed to handle augmentation.

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?


@ppegolo ppegolo marked this pull request as draft June 17, 2026 11:13
@jwa7 jwa7 marked this pull request as ready for review June 26, 2026 11:35
@jwa7 jwa7 requested a review from Luthaf June 26, 2026 11:35

@Luthaf Luthaf left a comment

Member

Looks good overall!

@@ -0,0 +1,527 @@
"""Private Wigner-d/Wigner-D helpers for symmetry operations.

Member

I'm wondering if this would be better inside https://github.com/Luthaf/wigners, but OK for now.

Contributor Author

Maybe for later, yes, unless you think it's trivial?

Member

Trivial is a strong word, but I might give it a quick try. Otherwise we can merge this and change it later

Member

maybe move this to metatomic/torch/_augmentation/_wigner.py

Member

this is a bit hard to follow, would be nice to make it clearer/better documented

from ._wigner import compute_wigner_batch


def _block_row_indices_by_system(

Member

Is this faster than doing samples.select(Labels("system", [[i]]))?

if len(data) != 1:
    raise ValueError(
        f"System data '{data_name}' has {len(data)} blocks; "
        "only single-block data is supported."

Member

why?

Contributor Author

No reason, now it's fine

Comment on lines +83 to +98
new_system.add_data(
    data_name,
    TensorMap(
        keys=data.keys,
        blocks=[
            TensorBlock(
                values=(
                    block.values.swapaxes(-1, -2) @ transformation.T
                ).swapaxes(-1, -2),
                samples=block.samples,
                components=block.components,
                properties=block.properties,
            )
        ],
    ),
)
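For readers following the thread, the contraction in the hunk above can be checked in isolation. The sketch below uses NumPy rather than torch and assumes values laid out as (samples, 3, properties); the array names and shapes are illustrative and are not taken from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed layout: (n_samples, 3 Cartesian components, n_properties).
values = rng.normal(size=(5, 3, 4))

# The identity below holds for any 3x3 matrix; QR just provides an orthogonal one.
transformation, _ = np.linalg.qr(rng.normal(size=(3, 3)))

# Pattern from the hunk: move the component axis last, right-multiply by R^T,
# then move the component axis back.
via_swap = (values.swapaxes(-1, -2) @ transformation.T).swapaxes(-1, -2)

# The same contraction written explicitly:
# out[s, i, p] = sum_j R[i, j] * values[s, j, p]
via_einsum = np.einsum("ij,sjp->sip", transformation, values)

assert np.allclose(via_swap, via_einsum)
```

In other words, each (sample, property) column of three components is left-multiplied by the transformation matrix.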

Member

This should use _transform_tmap

    transformations: list[torch.Tensor],
    extra_data: dict[str, TensorMap] | None = None,
) -> tuple[list[System], dict[str, TensorMap], dict[str, TensorMap]]:
    """Apply a batch of O(3) transformations to systems and TensorMaps simultaneously.

Member

the transformations are rotation + inversion, right? I'm wondering if a better API would be expressed in terms of transformations or rotations, instead of calling these "augmentations".

What you can do in the end is the same, just makes it a bit clearer when looking for functionality

@ppegolo ppegolo force-pushed the augmentation branch 2 times, most recently from 5e6bae5 to adfce14 Compare June 26, 2026 14:53
ppegolo added 3 commits June 26, 2026 21:48
Real Wigner-D matrices for O(3), used to rotate spherical (o3_mu) tensors.
`compute_wigner_batch` builds a batch of real D matrices up to a given
angular momentum from ZYZ Euler angles via the standard recursion. A small
TorchScript-compatibility `jit` helper (`_jit_compat`) lets the recursion be
scripted.
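As a rough illustration of the ZYZ Euler angles this commit message refers to, the sketch below decomposes a proper rotation matrix as R = Rz(alpha) Ry(beta) Rz(gamma) with NumPy. The active-rotation convention, the function names and the choice gamma = 0 in the degenerate case are assumptions made for this example; the PR's own conventions are not shown on this page.

```python
import numpy as np


def rot_z(angle):
    """Active rotation about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(angle):
    """Active rotation about the y axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def from_zyz(alpha, beta, gamma):
    """Compose R = Rz(alpha) @ Ry(beta) @ Rz(gamma)."""
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)


def zyz_angles(R, eps=1e-12):
    """ZYZ angles of a proper rotation matrix (assumed convention, see above)."""
    sin_beta = np.hypot(R[0, 2], R[1, 2])
    beta = np.arctan2(sin_beta, R[2, 2])
    if sin_beta < eps:
        # Gimbal lock (beta = 0 or beta = pi): only alpha + gamma or
        # alpha - gamma is determined, so fix gamma = 0 and put it all in alpha.
        alpha = np.arctan2(-R[0, 1], R[1, 1])
        gamma = 0.0
    else:
        alpha = np.arctan2(R[1, 2], R[0, 2])
        gamma = np.arctan2(R[2, 1], -R[2, 0])
    return alpha, beta, gamma


R_generic = from_zyz(0.3, 1.1, -0.7)
R_locked = from_zyz(0.4, np.pi, 0.9)
assert np.allclose(from_zyz(*zyz_angles(R_generic)), R_generic)
assert np.allclose(from_zyz(*zyz_angles(R_locked)), R_locked)
```

The angles returned in the degenerate branches differ from the inputs, but they reproduce the same matrix, which is all that a Wigner-D construction needs.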

`apply_transformations` applies a batch of per-system O(3) matrices (proper
or improper rotations) to a list of Systems and their target/extra-data
TensorMaps simultaneously, deriving the needed Wigner-D matrices from the
matrices themselves; `random_rotations` samples a uniform O(3) batch.
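One common way to draw a uniform batch from O(3) is sketched below: normalized 4D Gaussian samples give uniformly distributed unit quaternions (hence uniform proper rotations), and an independent coin flip adds the inversion. This is not necessarily how `random_rotations` is implemented in the PR; the function name `sample_o3` and its signature are hypothetical.

```python
import numpy as np


def sample_o3(n, rng):
    """Return an (n, 3, 3) batch of random O(3) matrices (hypothetical helper)."""
    # Unit quaternions (w, x, y, z), uniform on the 3-sphere.
    q = rng.normal(size=(n, 4))
    q /= np.linalg.norm(q, axis=1, keepdims=True)
    w, x, y, z = q.T

    # Standard quaternion-to-rotation-matrix formula, filled entry by entry.
    R = np.empty((n, 3, 3))
    R[:, 0, 0] = 1 - 2 * (y**2 + z**2)
    R[:, 0, 1] = 2 * (x * y - w * z)
    R[:, 0, 2] = 2 * (x * z + w * y)
    R[:, 1, 0] = 2 * (x * y + w * z)
    R[:, 1, 1] = 1 - 2 * (x**2 + z**2)
    R[:, 1, 2] = 2 * (y * z - w * x)
    R[:, 2, 0] = 2 * (x * z - w * y)
    R[:, 2, 1] = 2 * (y * z + w * x)
    R[:, 2, 2] = 1 - 2 * (x**2 + y**2)

    # Multiply half of the samples (on average) by -1 to get improper rotations.
    signs = rng.choice([-1.0, 1.0], size=n)
    return R * signs[:, None, None]


batch = sample_o3(64, np.random.default_rng(0))
assert np.allclose(batch @ batch.transpose(0, 2, 1), np.eye(3))
assert np.allclose(np.abs(np.linalg.det(batch)), 1.0)
```

Passing the generator explicitly keeps the sampling reproducible, in the spirit of the "reproducible generator" check mentioned in the test commit.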

Each component axis is transformed by tensor type inferred from its name:
Cartesian axes (xyz/xyz_1/xyz_2) are contracted with R directly (so improper
rotations flip them), and spherical axes (o3_mu/_1/_2) with the Wigner-D
matrix of their o3_lambda, plus a (-1)^l * sigma parity per spherical axis
when R is improper. A single TensorMap may mix scalar, Cartesian and
spherical blocks, and gradients are transformed as blocks with extra axes.
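The parity rule can be checked by hand for l = 1, where the real Wigner-D matrix of a proper rotation is just the rotation matrix with its axes permuted. The sketch below assumes the common real-spherical-harmonics ordering m = -1, 0, 1 corresponding to (y, z, x) and uses sigma = +1 for a vector and sigma = -1 for a pseudovector; these conventions and the helper name `transform_l1` are assumptions for illustration, not code from the PR.

```python
import numpy as np

# Permutation from Cartesian (x, y, z) to the assumed l=1 order (y, z, x).
P = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
L = 1


def transform_l1(spherical, R, sigma):
    """Transform l=1 spherical components under an O(3) matrix R."""
    det = np.sign(np.linalg.det(R))
    R_proper = det * R               # strip the inversion, if present
    D = P @ R_proper @ P.T           # real Wigner-D for l=1 (assumed ordering)
    parity = 1.0 if det > 0 else sigma * (-1.0) ** L
    return parity * (D @ spherical)


rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
R_proper = Q * np.sign(np.linalg.det(Q))   # force det = +1
R_improper = -R_proper                     # rotation combined with inversion

a, b = rng.normal(size=3), rng.normal(size=3)

# A vector (sigma = +1) must agree with the direct Cartesian contraction.
assert np.allclose(transform_l1(P @ a, R_improper, sigma=1), P @ (R_improper @ a))

# A pseudovector (sigma = -1), here a cross product, does not flip sign.
expected = np.cross(R_improper @ a, R_improper @ b)
assert np.allclose(transform_l1(P @ np.cross(a, b), R_improper, sigma=-1), P @ expected)
```

This is the same kind of cross-check against the Cartesian path that the test commit describes, restricted to the simplest non-trivial angular momentum.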

Value rows are routed to systems by their "system" label (remapping
arbitrary dataset indices onto the provided systems); gradients are routed
by the parent block's "system" label via the gradient "sample" column.
System geometry, registered per-atom data and neighbor-list vectors are
rotated too. The public entry point validates that the transformations are
3x3 and orthogonal and that transformations, systems and TensorMaps share a
dtype and device.
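The routing by "system" label can be pictured with plain arrays. The sketch below assumes that the sorted unique labels are mapped, in order, onto the provided per-system matrices; that mapping rule, the array layout and the function name are assumptions, since the PR's implementation is not shown on this page.

```python
import numpy as np


def rotate_rows_by_system(values, system_labels, transformations):
    """Rotate each row of `values` with the matrix of the system it belongs to.

    values:          (n_rows, 3, n_properties) Cartesian block values (assumed layout)
    system_labels:   (n_rows,) arbitrary integer dataset indices
    transformations: (n_systems, 3, 3) one O(3) matrix per provided system
    """
    system_labels = np.asarray(system_labels)
    transformations = np.asarray(transformations)

    # Remap arbitrary labels (e.g. 3, 7, 12) onto positions 0, 1, 2.
    unique, inverse = np.unique(system_labels, return_inverse=True)
    if len(unique) > len(transformations):
        raise ValueError("more distinct systems than transformations")

    per_row = transformations[inverse]            # (n_rows, 3, 3)
    return np.einsum("rij,rjp->rip", per_row, values)


rng = np.random.default_rng(0)
matrices = np.stack([np.linalg.qr(rng.normal(size=(3, 3)))[0] for _ in range(3)])
values = rng.normal(size=(4, 3, 2))
labels = [7, 3, 7, 12]                            # unsorted, non-contiguous

rotated = rotate_rows_by_system(values, labels, matrices)
# Label 3 maps to position 0, label 7 to position 1, label 12 to position 2.
assert np.allclose(rotated[1], matrices[0] @ values[1])
assert np.allclose(rotated[0], matrices[1] @ values[0])
```

Gathering one matrix per row avoids a Python loop over systems, at the cost of materializing an (n_rows, 3, 3) array.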

Cross-checks the spherical (Wigner-D) path against the trivially-correct
Cartesian path for general rotations, including both gimbal-lock branches of
the ZYZ decomposition (beta = 0 and beta = pi) and improper rotations.
Covers gradient rotation and per-system routing (including unsorted parent
"system" labels and the arbitrary-dataset-index remap), System geometry,
registered scalar/Cartesian/spherical data, neighbor-list vectors,
`random_rotations` (orthogonality, inversions, reproducible generator) and
the public-API validation errors.