Static array shape checking for JAX powered by jax.eval_shape.
Reads jaxtyping annotations and verifies shapes at analysis time -- no runtime cost, no FLOPs. Unlike jaxtyping's runtime beartype checks, jaxtyc catches shape bugs before your code ever runs on a GPU.
jaxtyc traces your annotated functions with jax.eval_shape, which propagates shapes through JAX operations without allocating any arrays. Each named dimension (batch, d_model, etc.) is assigned a unique prime number (>= 101), so distinct dimension names always produce distinct sizes -- making shape mismatches unambiguous and impossible to mask by accident.
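The tracing idea above can be sketched directly with `jax.eval_shape` and `jax.ShapeDtypeStruct`. This is a minimal illustration, not jaxtyc's implementation: the `DIMS` table, the `spec` and `traced_shape` helpers, and the `pooled` function are hypothetical names invented for this example, and the prime values are simply the first four primes at or above 101.

```python
import jax
import jax.numpy as jnp

# Hypothetical name-to-prime table; jaxtyc's real assignment may differ.
DIMS = {"batch": 101, "seq": 103, "d_in": 107, "d_out": 109}


def spec(names: str) -> jax.ShapeDtypeStruct:
    """Build an abstract array description from space-separated dimension names."""
    shape = tuple(DIMS[name] for name in names.split())
    return jax.ShapeDtypeStruct(shape, jnp.float32)


def linear(x, w):
    # (batch, seq, d_in) @ (d_in, d_out) -> (batch, seq, d_out)
    return jnp.matmul(x, w)


def pooled(x, w):
    # Averages over the seq axis, so the result has rank 2 instead of 3.
    return jnp.matmul(x, w).mean(axis=1)


def traced_shape(fn, *arg_names):
    """Propagate shapes through fn abstractly; no array data is allocated."""
    out = jax.eval_shape(fn, *(spec(names) for names in arg_names))
    return out.shape


expected = spec("batch seq d_out").shape
ok_shape = traced_shape(linear, "batch seq d_in", "d_in d_out")
bad_shape = traced_shape(pooled, "batch seq d_in", "d_in d_out")

print(ok_shape == expected)   # linear matches the declared return shape
print(bad_shape == expected)  # pooled does not
```

Because every name maps to a different size, a traced shape can be mapped back to dimension names without ambiguity, which is what makes a report like "expected (batch, seq, d_out)" possible.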
- Zero runtime cost -- `jax.eval_shape` only; no arrays allocated, no computation executed
- Prime-based symbolic shapes -- each dimension name maps to a unique prime (>= 101), so `d_in != d_out` is guaranteed
- 18 diagnostic rules -- shape/rank mismatch, cross-function propagation, parameter consistency, sharding validation, tuple return checking, trace errors
- Sharding-in-types -- annotate mesh axes inline with `|` syntax (`"batch|dp seq|None d_model|mp"`), validated against mesh topology with 8 dedicated sharding rules
- Einops integration -- detects `einops.rearrange`/`reduce`/`repeat` calls, parses pattern strings for shape checking, and provides einops-aware fix suggestions
- Inline suppressions -- `# jaxtyc: ignore` and `# jaxtyc: ignore[rule-name]`
- LSP server -- diagnostics, hover, CodeLens, go-to-definition, references, rename, code actions, completion, semantic tokens, inlay hints, signature help, linked editing, folding, call hierarchy
- LSP multiplexer -- `jaxtyc mux` runs ty/pyright + jaxtyc behind a single stdio pipe
- CLI with 4 output formats -- `full`, `concise`, `json`, `github` (inline PR annotations)
- Flax NNX + Equinox support -- traces bound methods on module instances
- Configurable via `pyproject.toml` -- severity threshold, rule ignoring, file exclusion, einops preferences
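To make the prime-based shapes and the `|` sharding syntax from the list above concrete, here is a small standard-library sketch. It assumes an annotation is a space-separated list of `dim|axis` tokens and that primes are handed out in order of first appearance; `parse_annotation`, `assign_primes` and `unknown_axes` are hypothetical helpers for illustration and are not part of jaxtyc's API.

```python
from itertools import count


def primes_from(start):
    """Yield primes >= start using trial division (fine for small values)."""
    for n in count(max(start, 2)):
        if all(n % p for p in range(2, int(n ** 0.5) + 1)):
            yield n


def parse_annotation(annotation):
    """Split 'dim|axis' tokens into (dim, axis) pairs; axis is None when unsharded."""
    pairs = []
    for token in annotation.split():
        dim, _, axis = token.partition("|")
        pairs.append((dim, None if axis in ("", "None") else axis))
    return pairs


def assign_primes(dim_names, start=101):
    """Give each distinct dimension name its own prime, in order of first use."""
    gen = primes_from(start)
    table = {}
    for name in dim_names:
        if name not in table:
            table[name] = next(gen)
    return table


def unknown_axes(pairs, mesh_axes):
    """Return mesh axes used in the annotation that the mesh does not define."""
    return [axis for _, axis in pairs if axis is not None and axis not in mesh_axes]


pairs = parse_annotation("batch|dp seq|None d_model|mp")
sizes = assign_primes(dim for dim, _ in pairs)

print(pairs)
print(sizes)
print(unknown_axes(pairs, {"dp"}))  # 'mp' is not an axis of this assumed mesh
```

A real checker would also need to confirm that each sharded dimension is compatible with the mesh axis size; this sketch only checks that the axis names exist.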
pip install jaxtyc
# or
uv add jaxtyc

Extras:
| Extra | Installs | Use case |
|---|---|---|
| `jaxtyc[watch]` | `watchfiles` | `jaxtyc watch` -- re-check on file save |
| `jaxtyc[flax]` | `flax >=0.10` | Flax NNX module tracing |
| `jaxtyc[equinox]` | `equinox >=0.11` | Equinox module tracing |
| `jaxtyc[einops]` | `einops >=0.8` | einops-style fix suggestions + inlay hints with pattern dim names |
| `jaxtyc[all]` | All of the above | Everything |
# model.py
import jax.numpy as jnp
from jaxtyping import Array, Float
def linear(
    x: Float[Array, "batch seq d_in"],
    w: Float[Array, "d_in d_out"],
) -> Float[Array, "batch seq d_out"]:
    return jnp.matmul(x, w.T)  # Bug: .T swaps dims, produces (batch, seq, d_in)

$ jaxtyc check model.py
model.py:8:0: error[shape-mismatch]
Shape mismatch in return of `linear`
Expected: (batch, seq, d_out)
Got: (batch, seq, d_in)
Found 1 error(s) in 1 function(s) checked (0.03s)

Fix: keep `w.T` and annotate `w` as `"d_out d_in"`, or drop the transpose and use `jnp.matmul(x, w)` with `w: Float[Array, "d_in d_out"]`.
Install from the VS Code Marketplace or search "jaxtyc" in the Extensions view.
Or build from source:
cd editors/vscode && npm install && npm run bundle
npx @vscode/vsce package --allow-missing-repository
code --install-extension jaxtyc-*.vsix

Or use the justfile: just vscode-update
The extension auto-discovers your Python environment (.venv, VIRTUAL_ENV, or jaxtyc on PATH) and starts the LSP server automatically. Supports multi-root workspaces with per-folder LSP clients. Includes jaxtyping snippets, a trace visualization webview, and a status bar quick pick menu.
jaxtyc works in any editor that supports LSP (Neovim, Helix, etc.). See the editor setup docs for configuration.
jaxtyc check <paths>... # Shape-check files or directories
jaxtyc trace <file.py::func> # Trace intermediate shapes through a function
jaxtyc watch <paths>... # Watch and re-check on change
jaxtyc lsp # Start the LSP server (stdio)
jaxtyc mux # Start the LSP multiplexer (ty/pyright + jaxtyc)
jaxtyc version # Print version
Use --format github to get inline annotations on pull request diffs:
name: Shape Check
on: [push, pull_request]
jobs:
  jaxtyc:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v6
      - uses: astral-sh/setup-uv@v7
      - run: uv sync
      - run: uv run jaxtyc check src/ --format github

Each shape error appears as an annotation on the exact file and line in the PR. See the CI docs for JSON output, configuration, and full pipeline examples.
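For readers curious how inline PR annotations work, GitHub Actions turns specially formatted lines on stdout (workflow commands such as `::error file=...,line=...::message`) into annotations. The sketch below formats one diagnostic that way. It is an assumption about what a `github`-style formatter could look like, not jaxtyc's actual output code, and `github_annotation` is a hypothetical helper.

```python
def _escape_data(text: str) -> str:
    """Escape the message part of a workflow command."""
    return text.replace("%", "%25").replace("\r", "%0D").replace("\n", "%0A")


def _escape_property(text: str) -> str:
    """Escape a property value, which must also protect ':' and ','."""
    return _escape_data(text).replace(":", "%3A").replace(",", "%2C")


def github_annotation(path, line, rule, message, level="error"):
    """Format one diagnostic as a GitHub Actions workflow command."""
    props = ",".join(
        f"{key}={_escape_property(str(value))}"
        for key, value in (("file", path), ("line", line), ("title", rule))
    )
    return f"::{level} {props}::{_escape_data(message)}"


annotation = github_annotation(
    "model.py",
    8,
    "shape-mismatch",
    "Shape mismatch in return of `linear`\n"
    "Expected: (batch, seq, d_out)\n"
    "Got: (batch, seq, d_in)",
)
print(annotation)
```

Printing such a line from any step in the workflow is enough for GitHub to attach the message to the given file and line in the pull request diff.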
Full docs at beegass.github.io/jaxtyc.
Contributions are welcome! See CONTRIBUTING.md for guidelines.
MIT


