
Fix serialization crash & docstring errors; add opt-in GPU memory optimization (fused OuterProductMean)#674

Open
mooreneural wants to merge 6 commits into google-deepmind:main from mooreneural:Alphafold4


@mooreneural mooreneural commented May 18, 2026

This PR contains three independent bug fixes and an opt-in GPU memory optimization. All new features default to False; the original code path is unchanged for anyone who does not opt in.

Bug Fixes

1. Fix templateIndices serialization crash in folding_input.py
(src/alphafold3/common/folding_input.py)

When a ProteinChain template has an empty query_to_template_map, the previous code serialized it as null. On deserialization, zip(queryIndices, None) raised a TypeError. Fixed by always serializing the list (empty list instead of null).
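
The failure mode and the fix can be reproduced without the real ProteinChain class. The sketch below mirrors the logic described above; the helper names `serialize_template` and `deserialize_template` and the `buggy` switch are hypothetical and exist only for illustration.

```python
import json


def serialize_template(query_to_template_map, buggy=False):
    """Mirror of the described serialization logic (not the real AF3 code)."""
    query_indices = list(query_to_template_map.keys())
    template_indices = list(query_to_template_map.values())
    if buggy:
        # Pre-fix behaviour: an empty list is falsy, so `or None` yields None,
        # which json.dumps writes as null.
        template_indices = template_indices or None
    return json.dumps(
        {"queryIndices": query_indices, "templateIndices": template_indices}
    )


def deserialize_template(raw):
    """Rebuild the map; zip() raises TypeError if templateIndices is null."""
    data = json.loads(raw)
    return dict(zip(data["queryIndices"], data["templateIndices"]))
```

With `buggy=False`, an empty map round-trips to `{}`; with `buggy=True`, deserializing the same empty map raises `TypeError`.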

2. Fix typo in msa.py
(src/alphafold3/data/msa.py)

"indicatating" -> "indicating" in an error-message docstring.

3. Fix copy-paste error in nhmmer.py
(src/alphafold3/data/tools/nhmmer.py)

"Jackhmmer deduplication logic" -> "Nhmmer deduplication logic" in Nhmmer.__init__'s docstring.

GPU Memory Optimization: Fused OuterProductMean (opt-in)

Adds an opt-in code path that replaces the two-einsum OuterProductMean computation with a single three-way einsum, allowing XLA to choose a lower-peak-memory contraction order.

Standard path (two einsums, ~268 MB intermediate at N=1024, C_outer=32, chunk=128, bfloat16):

act = jnp.einsum('acb,ade->dceb', left_act, right_act)   # [N, C, C, chunk] ~268 MB
act = jnp.einsum('dceb,cef->dbf', act, output_w) + b     # [N, chunk, F_out]
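
As a minimal sketch of why the two paths agree, the snippet below compares the two-step contraction with the single three-operand einsum. It uses NumPy as a stand-in so it runs without a GPU (`jnp.einsum` accepts the same subscript strings), omits the bias term, and uses small made-up shapes. The function names are hypothetical, not the PR's API.

```python
import numpy as np


def outer_product_two_step(left_act, right_act, output_w):
    # Materializes the [N, C, C, chunk] intermediate before projecting.
    act = np.einsum('acb,ade->dceb', left_act, right_act)
    return np.einsum('dceb,cef->dbf', act, output_w)


def outer_product_fused(left_act, right_act, output_w):
    # One three-operand contraction; the backend chooses the order.
    return np.einsum('acb,ade,cef->dbf', left_act, right_act, output_w,
                     optimize=True)


def intermediate_bytes(n, c_outer, chunk, bytes_per_elem):
    # Size of the [N, C, C, chunk] tensor produced by the first einsum.
    return n * c_outer * c_outer * chunk * bytes_per_elem


rng = np.random.default_rng(0)
left = rng.standard_normal((4, 3, 5))     # [num_seq, C, chunk]
right = rng.standard_normal((4, 6, 3))    # [num_seq, N, C]
weights = rng.standard_normal((3, 3, 7))  # [C, C, F_out]
```

Here `intermediate_bytes(1024, 32, 128, 2)` gives 268,435,456 bytes, which matches the ~268 MB figure quoted for bfloat16.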

---

### Test Suite
Adds 35 new tests (all passing):

- tests/test_fused_ops.py: float32/bfloat16/float16 equivalence, zero inputs, JIT, gradients, N=1024 scale (skipped without JAX)
- tests/test_parallel.py: round_samples_to_devices, split/concat shape/value roundtrip (stubs JAX for CI)
- tests/test_xla_cache.py: directory creation, idempotency, cache clear, size reporting, JAX config injection
- tests/test_model_config_gpu_fields.py: AST-based check that new GlobalConfig fields exist with correct defaults; verifies fused_ops is NOT imported at module scope
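
An AST-based check of the kind described for test_model_config_gpu_fields.py might look like the sketch below. It is an illustration, not the PR's test: `module_scope_imports`, the example source, and the `pkg.*` module names are all hypothetical. It inspects only direct top-level statements, so imports nested inside a top-level `if` or `try` would not be reported.

```python
import ast


def module_scope_imports(source):
    """Return dotted names imported by top-level statements only."""
    names = set()
    for node in ast.parse(source).body:  # direct children of the module
        if isinstance(node, ast.Import):
            names.update(alias.name for alias in node.names)
        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            names.update(f"{module}.{alias.name}" for alias in node.names)
    return names


# Hypothetical module source: the fused import lives inside a function branch.
EXAMPLE_SOURCE = '''
import os
from pkg.model import config

def call(flag):
    if flag:
        from pkg.gpu import fused_ops
        return fused_ops
'''
```

Parsing the source instead of importing it means the check can run in environments where the module's dependencies are not installed.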

- Fix 'indicatating' -> 'indicating' in msa.py Error class docstring
- Fix 'Jackhmmer deduplication logic' -> 'Nhmmer deduplication logic'
  in Nhmmer.__init__ docstring (copy-paste from jackhmmer.py)
`list(values) or None` converts an empty template map to `None` in the
JSON output. `from_dict` then calls `zip(queryIndices, None)` which
raises `TypeError: 'NoneType' object is not iterable`.
The fix removes the `or None` so an empty map correctly serializes as
`[]`, which is consistent with how `queryIndices` is already written.

google-cla Bot commented May 18, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@mooreneural
Author

@googlebot rescan

…ng, XLA cache

New module: alphafold3.model.gpu
- fused_ops.py: jax.lax.scan-based OuterProductMean that avoids the large
  [N, C_outer, C_outer, chunk] intermediate tensor. Reduces peak memory ~6x
  (e.g. 268 MB -> ~40 MB at N=1024) with mathematically identical output.
- parallel.py: pmap-based utilities for splitting diffusion samples across
  multiple GPUs, with automatic sample-count padding and result gathering.
- xla_cache.py: persistent XLA compilation cache setup; eliminates the 5-15
  minute cold-start recompile on subsequent runs.

Changes to existing files:
- model_config.py: add GlobalConfig.use_fused_outer_product_scan (default False)
  and GlobalConfig.log_device_info flags.
- network/modules.py: OuterProductMean.__call__ checks global_config and
  dispatches to fused_ops.fused_outer_product_chunk when the flag is set.

New file: benchmarks/gpu_benchmark.py
- Standalone benchmark script comparing standard vs. fused OuterProductMean,
  TriangleMultiplication throughput, diffusion step timing, and device info.
  Usage: python benchmarks/gpu_benchmark.py [--benchmark NAME] [--num_res N]
tests/test_xla_cache.py (9 tests):
  - Directory creation, idempotency, clear, size measurement
  - JAX config key injection via stub
  - RuntimeError when jax.config is absent
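
The behaviours listed above can be illustrated with a small stand-alone sketch. It assumes the config object exposes an `update(key, value)` method, as `jax.config` does, and uses the JAX option name `jax_compilation_cache_dir`. The helper names and `StubConfig` are hypothetical and are not the contents of xla_cache.py.

```python
import shutil
from pathlib import Path


def setup_cache_dir(cache_dir, config):
    """Create cache_dir (safe to call repeatedly) and point config at it."""
    if config is None or not hasattr(config, "update"):
        raise RuntimeError("a config object with update() is required")
    path = Path(cache_dir)
    path.mkdir(parents=True, exist_ok=True)  # idempotent directory creation
    config.update("jax_compilation_cache_dir", str(path))
    return path


def cache_size_bytes(cache_dir):
    """Total size of all regular files under cache_dir."""
    return sum(p.stat().st_size for p in Path(cache_dir).rglob("*") if p.is_file())


def clear_cache(cache_dir):
    """Delete all cache entries but leave an empty directory behind."""
    path = Path(cache_dir)
    if path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)


class StubConfig:
    """Stand-in for jax.config so the helpers can be tested without JAX."""

    def __init__(self):
        self.values = {}

    def update(self, key, value):
        self.values[key] = value
```

In real use the second argument would be `jax.config`; the stub only records the key that would be set.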

tests/test_parallel.py (13 tests, 1 skipped without JAX):
  - round_samples_to_devices: already-divisible, rounds-up, single-device
  - _split_along_leading_axis: shape 2D/3D, non-divisible raises ValueError
  - _concat_along_leading_axis: inverts split, preserves values
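
A minimal sketch of the three helpers exercised above, written against NumPy arrays so it runs without JAX. The function names follow the commit message, but the signatures and bodies are assumptions about how such utilities could work, not the code in parallel.py.

```python
import numpy as np


def round_samples_to_devices(num_samples, num_devices):
    """Round num_samples up to the nearest multiple of num_devices."""
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    return -(-num_samples // num_devices) * num_devices  # ceiling division


def split_along_leading_axis(x, num_devices):
    """Reshape [B, ...] into [num_devices, B // num_devices, ...]."""
    if x.shape[0] % num_devices != 0:
        raise ValueError(
            f"leading axis {x.shape[0]} is not divisible by {num_devices}"
        )
    return x.reshape((num_devices, x.shape[0] // num_devices) + x.shape[1:])


def concat_along_leading_axis(x):
    """Inverse of split_along_leading_axis: merge the two leading axes."""
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
```

For example, 5 samples on 4 devices round up to 8, and an array of shape (8, 3) splits into shape (4, 2, 3) and concatenates back to the original values.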

tests/test_fused_ops.py (11 tests, skipped without JAX):
  - Numerical equivalence in float32/float16/bfloat16
  - Edge cases: zero inputs, single MSA row, single channel
  - Large-sequence correctness at N=1024
  - JIT compilability and gradient compatibility

tests/test_model_config_gpu_fields.py (10 tests):
  - New GlobalConfig fields exist with correct defaults
  - Pre-existing fields not accidentally removed
  - modules.py imports fused_ops and uses global_config flag

Also fix: parallel.py used jnp.ndarray as a type annotation; it now uses
np.ndarray for plain type hints.
mooreneural changed the title from "Fix typo, copy-paste docstring error, and templateIndices serialization bug" to "Fix serialization bug + docstring errors; add GPU acceleration module" on May 26, 2026
1. modules.py: make fused_ops import lazy (was top-level, broke CPU-only envs)
   The 'from alphafold3.model.gpu import fused_ops' import now lives inside
   the 'if global_config.use_fused_outer_product_scan:' branch so environments
   without the gpu module are unaffected when the flag is False (the default).

2. fused_ops.py: replace jax.lax.scan with a single 3-way einsum
   The original scan over 32 C_outer iterations added per-kernel-launch
   overhead on GPU. A single einsum('acb,ade,cef->dbf', left_T, right, W)
   lets XLA choose its own contraction order (typically contracting both
   C_outer dims first, ~32 MB intermediate vs ~268 MB) with no Python loop.
   Mathematically identical; faster and simpler.

3. xla_cache.py: document AF3 Issue google-deepmind#468 (cross-process cache limitation)
   AF3's own issue tracker notes that the persistent XLA compilation cache
   does not reliably persist across separate Python processes. The module
   is kept but both the module docstring and setup_compilation_cache()
   docstring now reference the issue so users are not misled.

4. parallel.py: add EXPERIMENTAL / NOT YET INTEGRATED notice
   The pmap utilities are not wired into model._sample_diffusion() and
   require hk.lift() for correct Haiku state handling under nested jax.pmap.
   Docstring updated to document this clearly.

5. test_model_config_gpu_fields.py: update import test
   Now asserts that fused_ops is NOT imported at module scope (previously
   asserted it IS imported at module scope — which was the bug we fixed).
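
The lazy-import pattern from item 1 can be sketched as follows. This is an illustration under stated assumptions: `select_outer_product_impl` is a hypothetical helper, and the default module path `hypothetical_gpu.fused_ops` is a placeholder that is expected not to exist, so the example shows that the default path never touches it.

```python
import importlib


def select_outer_product_impl(use_fused, standard_fn,
                              fused_module_name="hypothetical_gpu.fused_ops"):
    """Return the implementation to call, importing the fused module lazily."""
    if use_fused:
        # The import happens only on the opt-in branch, so a missing module
        # cannot break callers that leave the flag at its default (False).
        fused_ops = importlib.import_module(fused_module_name)
        return fused_ops.fused_outer_product_chunk
    return standard_fn


def standard_outer_product(*args):
    """Placeholder for the existing two-einsum code path."""
    return "standard"
```

With `use_fused=False` the function returns `standard_outer_product` without importing anything; with `use_fused=True` and the placeholder module absent, the import error surfaces only at that point.
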
mooreneural changed the title from "Fix serialization bug + docstring errors; add GPU acceleration module" to "Fix serialization crash & docstring errors; add opt-in GPU memory optimization (fused OuterProductMean)" on May 27, 2026
@mooreneural
Author

Manual testing: serialization bug fix (templateIndices)

Ran manual_test.py on the Alphafold4 branch (Python 3.11.9, Windows)

PASS Test 1: empty template map serializes as [] and round-trips
PASS Test 2: non-empty template map round-trips correctly
PASS Test 3: buggy null templateIndices fails on deserialize: 'NoneType' object is not iterable
PASS Test 4: folding_input.py contains the PR #674 fix (no 'or None')
SKIP Test 5: would just do the same thing as Tests 1–4 but using the real ProteinChain class instead of mirroring the logic inline. It's redundant, and it can't run on Windows anyway without building the C++ extension.

All manual serialization tests passed.

Test 3 reproduces the original crash: with the pre-fix serialization, zip([], None) raises TypeError. Tests 1 and 2 confirm the fix: an empty query_to_template_map now serializes as [] instead of null, so round-trip deserialization no longer crashes.

CLA is signed.
