Skip to content

feat(detection): add handler for remat2 primitive#141

Merged
adrhill merged 1 commit into
mainfrom
feat/remat2-handler
Jun 1, 2026
Merged

feat(detection): add handler for remat2 primitive#141
adrhill merged 1 commit into
mainfrom
feat/remat2-handler

Conversation

@adrhill
Copy link
Copy Markdown
Owner

@adrhill adrhill commented May 28, 2026

Summary

  • Add support for jax.checkpoint / jax.remat by handling the remat2 primitive
  • The primitive uses the same "jaxpr" parameter key as jit/pjit, so it reuses the existing prop_closed_jaxpr implementation

Test plan

  • test_remat2_checkpoint — sparsity detection through checkpoint
  • test_remat2_closure_captured_index — closure-captured indices resolve precisely
  • test_remat2_decompression — full Jacobian computation works

Closes #118

🤖 Generated with Claude Code

Add support for jax.checkpoint/jax.remat by handling the remat2 primitive.
The primitive uses the same "jaxpr" parameter key as jit/pjit,
so it reuses the existing prop_closed_jaxpr implementation.

Closes #118

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@codecov-commenter
Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 93.14%. Comparing base (b5b3c54) to head (0a8be0b).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #141   +/-   ##
=======================================
  Coverage   93.14%   93.14%           
=======================================
  Files          53       53           
  Lines        3518     3518           
=======================================
  Hits         3277     3277           
  Misses        241      241           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds sparsity-detection support for jax.checkpoint / jax.remat by handling the remat2 JAX primitive as a nested-jaxpr call, consistent with how jit/pjit/named_call are already handled in the interpreter.

Changes:

  • Dispatch remat2 equations through the existing prop_closed_jaxpr(..., param_key="jaxpr") path.
  • Add interpreter tests covering checkpoint/remat sparsity, closure-captured indices, and full Jacobian decompression through jax.checkpoint.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
src/asdex/detection/_interpret/__init__.py Treats remat2 like other closed-jaxpr call primitives by reusing the jaxpr-param propagation handler.
tests/_interpret/test_nested_jaxpr.py Adds regression tests to ensure checkpoint/remat traces correctly for sparsity and decompression, including precise gather tracking with closure-captured indices.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@adrhill adrhill merged commit 8b233fe into main Jun 1, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature: Add handler for remat2 primitive (checkpoint/remat)

3 participants