Skip to content

Remove lineax dependency#697

Open
marcocuturi wants to merge 3 commits into
mainfrom
remove-lineax
Open

Remove lineax dependency#697
marcocuturi wants to merge 3 commits into
mainfrom
remove-lineax

Conversation

@marcocuturi
Copy link
Copy Markdown
Contributor

Summary

  • Replace lineax with JAX-native CG solver (jax.lax.while_loop) in implicit differentiation
  • Replace lineax linear operator abstractions in regularizers.py with plain jnp.ndarray operations
  • Remove lineax from project dependencies

Motivation

lineax uses equinox's filter_closure_convert which produces jaxpr/consts mismatches when called inside VJP backward passes on JAX >= 0.10. This caused 39 test failures in sinkhorn_diff_test.py. The JAX-native CG implementation works correctly under all JAX transformations.

Test plan

  • All 91 sinkhorn_diff_test.py tests pass locally
  • All 94 regularizers_test.py tests pass
  • All soft_sort_test.py and potentials_test.py tests pass
  • CI passes on all Python versions

🤖 Generated with Claude Code

marcocuturi and others added 3 commits May 22, 2026 13:59
Replace all lineax usage with plain JAX operations:
- lineax_implicit.py: JAX-native CG via jax.lax.while_loop
- regularizers.py: plain ndarray operations instead of lineax operators
- Remove lineax from pyproject.toml dependencies
- Update tests and docs accordingly

lineax's equinox closure conversion is incompatible with JAX >= 0.10
when called inside VJP backward passes, causing 39 test failures.
The JAX-native implementation avoids this entirely.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
# Conflicts:
#	src/ott/solvers/linear/lineax_implicit.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@codecov
Copy link
Copy Markdown

codecov Bot commented May 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 86.88%. Comparing base (2ac97aa) to head (aa532e2).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #697      +/-   ##
==========================================
- Coverage   86.88%   86.88%   -0.01%     
==========================================
  Files          83       83              
  Lines        8656     8653       -3     
  Branches      593      593              
==========================================
- Hits         7521     7518       -3     
  Misses        983      983              
  Partials      152      152              
Files with missing lines Coverage Δ
src/ott/geometry/regularizers.py 98.63% <100.00%> (-0.02%) ⬇️
src/ott/solvers/linear/implicit_differentiation.py 89.41% <ø> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant