Fix serialization crash & docstring errors; add opt-in GPU memory optimization (fused OuterProductMean) #674
mooreneural wants to merge 6 commits into
- Fix 'indicatating' -> 'indicating' in msa.py Error class docstring
- Fix 'Jackhmmer deduplication logic' -> 'Nhmmer deduplication logic' in Nhmmer.__init__ docstring (copy-paste from jackhmmer.py)
`list(values) or None` converts an empty template map to `None` in the JSON output. `from_dict` then calls `zip(queryIndices, None)` which raises `TypeError: 'NoneType' object is not iterable`. The fix removes the `or None` so an empty map correctly serializes as `[]`, which is consistent with how `queryIndices` is already written.
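The failure mode described above can be reproduced in isolation. The following is a minimal sketch, not the code in `folding_input.py`: `to_dict_buggy`, `to_dict_fixed`, `from_dict`, and `round_trip` are hypothetical stand-ins that only mimic the `list(values) or None` pattern and the `zip` call mentioned in the commit message.

```python
import json


def to_dict_buggy(query_to_template_map):
    """Hypothetical stand-in for the old serializer."""
    return {
        "queryIndices": list(query_to_template_map.keys()),
        # An empty list is falsy, so `or None` turns [] into None (JSON null).
        "templateIndices": list(query_to_template_map.values()) or None,
    }


def to_dict_fixed(query_to_template_map):
    """Hypothetical stand-in for the fixed serializer: always emit a list."""
    return {
        "queryIndices": list(query_to_template_map.keys()),
        "templateIndices": list(query_to_template_map.values()),
    }


def from_dict(data):
    """Hypothetical stand-in for the deserializer: rebuilds the map with zip."""
    return dict(zip(data["queryIndices"], data["templateIndices"]))


def round_trip(to_dict, mapping):
    """Serialize to JSON text and read it back."""
    return from_dict(json.loads(json.dumps(to_dict(mapping))))
```

With these stand-ins, `round_trip(to_dict_fixed, {})` returns an empty dict, while `round_trip(to_dict_buggy, {})` raises `TypeError` because `zip` receives `None` as its second argument.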
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.

@googlebot rescan
…ng, XLA cache

New module: alphafold3.model.gpu
- fused_ops.py: jax.lax.scan-based OuterProductMean that avoids the large [N, C_outer, C_outer, chunk] intermediate tensor. Reduces peak memory ~6x (e.g. 268 MB -> ~40 MB at N=1024) with mathematically identical output.
- parallel.py: pmap-based utilities for splitting diffusion samples across multiple GPUs, with automatic sample-count padding and result gathering.
- xla_cache.py: persistent XLA compilation cache setup; eliminates the 5-15 minute cold-start recompile on subsequent runs.

Changes to existing files:
- model_config.py: add GlobalConfig.use_fused_outer_product_scan (default False) and GlobalConfig.log_device_info flags.
- network/modules.py: OuterProductMean.__call__ checks global_config and dispatches to fused_ops.fused_outer_product_chunk when the flag is set.

New file: benchmarks/gpu_benchmark.py
- Standalone benchmark script comparing standard vs. fused OuterProductMean, TriangleMultiplication throughput, diffusion step timing, and device info.

Usage: python benchmarks/gpu_benchmark.py [--benchmark NAME] [--num_res N]
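The 268 MB figure for the [N, C_outer, C_outer, chunk] intermediate can be checked with simple arithmetic. The sketch below assumes the values given in the PR description (N=1024, C_outer=32, chunk=128, bfloat16 at 2 bytes per element); `tensor_bytes` is a hypothetical helper written for this illustration.

```python
def tensor_bytes(shape, bytes_per_element):
    """Return the storage size in bytes of a dense tensor with this shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


# Values taken from the PR description; bfloat16 uses 2 bytes per element.
N, C_OUTER, CHUNK = 1024, 32, 128
BF16_BYTES = 2

intermediate_bytes = tensor_bytes((N, C_OUTER, C_OUTER, CHUNK), BF16_BYTES)
intermediate_mb = intermediate_bytes / 1e6  # decimal megabytes

print(f"{intermediate_bytes} bytes = {intermediate_mb:.0f} MB")
```

The product is 1024 × 32 × 32 × 128 × 2 = 268,435,456 bytes, which matches the quoted ~268 MB.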
tests/test_xla_cache.py (9 tests):
- Directory creation, idempotency, clear, size measurement
- JAX config key injection via stub
- RuntimeError when JAX.config is absent

tests/test_parallel.py (13 tests, 1 skipped without JAX):
- round_samples_to_devices: already-divisible, rounds-up, single-device
- _split_along_leading_axis: shape 2D/3D, non-divisible raises ValueError
- _concat_along_leading_axis: inverts split, preserves values

tests/test_fused_ops.py (11 tests, skipped without JAX):
- Numerical equivalence in float32/float16/bfloat16
- Edge cases: zero inputs, single MSA row, single channel
- Large-sequence correctness at N=1024
- JIT compilability and gradient compatibility

tests/test_model_config_gpu_fields.py (10 tests):
- New GlobalConfig fields exist with correct defaults
- Pre-existing fields not accidentally removed
- modules.py imports fused_ops and uses global_config flag

Also fix: parallel.py used jnp.ndarray as a type annotation. jnp.ndarray does not exist in JAX (use np.ndarray for plain type hints).
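The behaviours that tests/test_parallel.py checks can be illustrated with a small NumPy sketch. The function names mirror those in the test list, but the signatures and bodies below are assumptions made for illustration, not the code in parallel.py.

```python
import numpy as np


def round_samples_to_devices(num_samples, num_devices):
    """Round num_samples up to the nearest multiple of num_devices (assumed behaviour)."""
    return -(-num_samples // num_devices) * num_devices


def split_along_leading_axis(x, num_devices):
    """Reshape [S, ...] into [num_devices, S // num_devices, ...] (assumed behaviour)."""
    if x.shape[0] % num_devices != 0:
        raise ValueError(
            f"Leading axis {x.shape[0]} is not divisible by {num_devices} devices."
        )
    return x.reshape((num_devices, x.shape[0] // num_devices) + x.shape[1:])


def concat_along_leading_axis(x):
    """Invert the split: merge the first two axes back into one."""
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


samples = np.arange(12).reshape(6, 2)
per_device = split_along_leading_axis(samples, 3)  # shape (3, 2, 2)
restored = concat_along_leading_axis(per_device)   # shape (6, 2)
```

Padding the sample count first, for example `round_samples_to_devices(5, 4)` giving 8, guarantees that the split never hits the `ValueError` branch.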
1. modules.py: make fused_ops import lazy (was top-level, broke CPU-only envs)
The 'from alphafold3.model.gpu import fused_ops' import now lives inside
the 'if global_config.use_fused_outer_product_scan:' branch so environments
without the gpu module are unaffected when the flag is False (the default).
2. fused_ops.py: replace jax.lax.scan with a single 3-way einsum
The original scan over 32 C_outer iterations added per-kernel-launch
overhead on GPU. A single einsum('acb,ade,cef->dbf', left_T, right, W)
lets XLA choose its own contraction order (typically contracting both
C_outer dims first, ~32 MB intermediate vs ~268 MB) with no Python loop.
Mathematically identical; faster and simpler.
3. xla_cache.py: document AF3 Issue google-deepmind#468 (cross-process cache limitation)
AF3's own issue tracker notes that the persistent XLA compilation cache
does not reliably persist across separate Python processes. The module
is kept but both the module docstring and setup_compilation_cache()
docstring now reference the issue so users are not misled.
4. parallel.py: add EXPERIMENTAL / NOT YET INTEGRATED notice
The pmap utilities are not wired into model._sample_diffusion() and
require hk.lift() for correct Haiku state handling under nested jax.pmap.
Docstring updated to document this clearly.
5. test_model_config_gpu_fields.py: update import test
Now asserts that fused_ops is NOT imported at module scope (previously
asserted it IS imported at module scope — which was the bug we fixed).
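The equivalence claimed in item 2 can be checked numerically. The sketch below uses `np.einsum` as a stand-in for `jnp.einsum` and tiny, arbitrary dimension sizes. The array names follow the commit message, but the meaning assigned to each index letter is an assumption, and the mean normalization is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed index meanings: a = MSA rows (summed over), b and d = residue
# positions, c and e = the two C_outer channel dims, f = output channels.
A, B, D, C, E, F = 5, 4, 3, 2, 2, 6

left_T = rng.standard_normal((A, C, B))
right = rng.standard_normal((A, D, E))
W = rng.standard_normal((C, E, F))

# Two-step path: materialize the [d, c, e, b] outer product, then project.
outer = np.einsum("acb,ade->dceb", left_T, right)
two_step = np.einsum("dceb,cef->dbf", outer, W)

# Fused path: one three-operand einsum with no explicit intermediate.
fused = np.einsum("acb,ade,cef->dbf", left_T, right, W)

print(np.allclose(fused, two_step))
```

In the two-step path the `outer` array has D × C × E × B elements, which is the intermediate that grows large at real sequence lengths. The fused call leaves the contraction order to the backend.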
Manual testing: serialization bug fix (
This PR contains three independent bug fixes and an opt-in GPU memory optimization. All new features default to `False`; the original code path is unchanged for anyone not opting in.

Bug Fixes

1. Fix `templateIndices` serialization crash in `folding_input.py` (`src/alphafold3/common/folding_input.py`)

When a `ProteinChain` template has an empty `query_to_template_map`, the previous code serialized it as `null`. On deserialization, `zip(queryIndices, None)` raised a `TypeError`. Fixed by always serializing the list (empty list instead of `null`).

2. Fix typo in `msa.py` (`src/alphafold3/data/msa.py`)

"indicatating" → "indicating" in an error-message docstring.

3. Fix copy-paste error in `nhmmer.py` (`src/alphafold3/data/tools/nhmmer.py`)

"Jackhmmer deduplication logic" → "Nhmmer deduplication logic" in `Nhmmer.__init__`'s docstring.

GPU Memory Optimization: Fused OuterProductMean (opt-in)

Adds an opt-in code path that replaces the two-einsum `OuterProductMean` computation with a single three-way einsum, allowing XLA to choose a lower-peak-memory contraction order.

Standard path (two einsums, ~268 MB intermediate at N=1024, C_outer=32, chunk=128, bfloat16):