```sh
uv add hyjax  # or: pip install hyjax
```

Dev setup:

```sh
git clone https://github.com/mlegls/hyjax
cd hyjax
uv sync
uv run pytest                      # run the test suite
uv run hy examples/experiments.hy  # run the example script
```

> When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale” ("a soul of pure functional programming"). -- 🔪 JAX - The Sharp Bits 🔪

> “As one would expect from its goals, artificial intelligence research generates many significant programming problems. In other programming cultures this spate of problems spawns new languages. ... We toast the Lisp programmer who pens his thoughts within nests of parentheses.” -- Alan J. Perlis

> You now know enough to be dangerous with Hy. You may now smile villainously and sneak off to your Hydeaway to do unspeakable things. -- Hy Tutorial
The goals are:
- the name pun was funny
- let fully JIT-compiled JAX code be written in idiomatic Lisp
Some examples, loosely inspired by the JAX Quickstart tutorials, are in `examples/experiments.hy`. The Hy and JAX docs may also be helpful.
- `(defn/j f [args] body)` as `@jit def f(args): ...`
  - `/j` for jit; identical syntax to Hy's `defn`; supports other decorators, annotations, and variadic & keyword args
- `(mapv f vec)` as `vmap(f)(vec)`
- `(if/l pred then else)` as `lax.cond`
  - `/l` for lax; `if`-like binary form
- `(cond/l pred1 result1 pred2 result2 ... default)` as a chain of `lax.cond`
  - multi-branch form, mirroring Lisp `cond`
- `(if/lp [bindings] pred then else)`: explicit-binding ("lax pure") form of `if/l`; you list the closed-over symbols yourself instead of relying on auto-detection
- `(while/l [carry...] cond body-expr)` as `lax.while_loop`, in pure-expression style. Carry names are listed explicitly, and their initial values are taken from the enclosing scope. The body is a single expression that evaluates to the new carry (a scalar for one name, an N-tuple for N > 1). The loop form itself returns the final carry.

  ```hy
  ; single-state
  (setv v 0)
  (setv v (while/l [v] (< v 10) (+ v 1)))

  ; multi-state
  (setv a 0 b 1)
  (setv [a b] (while/l [a b] (< a 100) #(b (+ a b))))
  ```
- `(fori/l [carry...] [i (range ...)] body-expr)` as `lax.fori_loop`, in the same style. The bindings-first ordering matches Clojure's `loop` and CL's `do`: the state carry leads, and the iteration descriptor follows.

  ```hy
  (setv acc 0)
  (setv acc (fori/l [acc] [i (range n)] (+ acc i)))

  (setv a 0 b 1)
  (setv [a b] (fori/l [a b] [i (range n)] #(b (+ a b))))
  ```
- `(while/lp [pattern init] cond body-expr)` and `(fori/lp [pattern init] [i lo hi] body-expr)`: explicit-binding ("lax pure") forms of the loops. Use these when you want full control over what goes into the carry, e.g. to keep loop-local temporaries out of it, or to initialize with a different expression than a same-named outer binding.
- `(scan/l [carry...] [x xs] body-expr)` as `lax.scan`. The body evaluates to a `#(new-carry y)` pair; the form returns `#(final-carry stacked-ys)`. Same carry-from-scope convention as `while/l`/`fori/l`. `scan/lp` is the explicit `[pattern init]` variant.

  ```hy
  ; cumulative sum
  (setv acc 0)
  (setv [final sums] (scan/l [acc] [x xs] #((+ acc x) (+ acc x))))
  ```
- `(case/l idx expr0 expr1 ...)` as `lax.switch`. Like `cond/l`, but branches are selected by integer index; free symbols are auto-detected and closed over.
- `(ascan/l f xs)` as `lax.associative_scan`. Extra kwargs are spliced through: `(ascan/l + xs :reverse True)`.
- `(mapl f xs)` as `lax.map` (sequential), complementing `mapv`, which is `jax.vmap`.
- `(mapt f tree ...)` as `jax.tree.map`. Accepts multiple pytrees for co-mapping.

  ```hy
  (mapt (fn [p g] (- p (* lr g))) params grads)
  ```
- `(with-keys key [k1 k2 k3] body)`: PRNG key-splitting sugar. Consumes `key`, rebinds it to `split(key, 4)[0]`, and binds `k1`, `k2`, `k3` to the remaining slots. The rebinding escapes into the enclosing scope, so subsequent code sees the advanced key (matching how Python JAX code threads keys).

  ```hy
  (setv key (jax.random.PRNGKey 0))
  (with-keys key [k1 k2]
    (setv x (jax.random.normal k1 #(4)))
    (setv y (jax.random.normal k2 #(4))))
  ```
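For readers coming from Python JAX, here is a sketch of plain-JAX code corresponding to the Hy examples in the list above, using the `jax`/`jax.lax` call each bullet names as that macro's counterpart. It is hand-written against the documented JAX APIs and is not hyjax's actual macro expansion; the names `square`, `total`, `fa`/`fb`, `params`, and `grads` are hypothetical, chosen only so the snippets can share one file.

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# (defn/j square [x] (* x x)): a jit-decorated function.
# `square` is a hypothetical example name.
@jax.jit
def square(x):
    return x * x

# (mapv square xs) ~ jax.vmap; (mapl square xs) ~ lax.map (sequential).
vmapped = jax.vmap(square)(xs)
mapped = lax.map(square, xs)

# (setv v (while/l [v] (< v 10) (+ v 1)))
# One carry name, so the loop state is a bare scalar.
v = 0
v = lax.while_loop(lambda v: v < 10, lambda v: v + 1, v)

# (setv [a b] (while/l [a b] (< a 100) #(b (+ a b))))
# Two carry names, so the loop state is a 2-tuple.
a, b = 0, 1
a, b = lax.while_loop(lambda s: s[0] < 100,
                      lambda s: (s[1], s[0] + s[1]),
                      (a, b))

# (setv acc (fori/l [acc] [i (range n)] (+ acc i)))
# lax.fori_loop's body receives the index first, then the carry.
n = 5
acc = 0
acc = lax.fori_loop(0, n, lambda i, acc: acc + i, acc)

# Fibonacci with fori/l over (range 10), using hypothetical names fa/fb.
fa, fb = 0, 1
fa, fb = lax.fori_loop(0, 10, lambda i, s: (s[1], s[0] + s[1]), (fa, fb))

# (setv [final sums] (scan/l [total] [x xs] #((+ total x) (+ total x))))
# The step returns (new-carry, y); lax.scan returns (final-carry, stacked ys).
total = jnp.zeros(())  # float32 scalar carry, matching the dtype of xs
final, sums = lax.scan(lambda c, x: (c + x, c + x), total, xs)

# (ascan/l + xs :reverse True): the extra kwarg passes straight through.
rev_sums = lax.associative_scan(jnp.add, xs, reverse=True)

# (mapt (fn [p g] (- p (* lr g))) params grads): co-map over two pytrees.
lr = 0.1
params = {"w": jnp.ones(3), "b": jnp.zeros(())}
grads = {"w": jnp.ones(3), "b": jnp.ones(())}
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)

# (with-keys key [k1 k2] ...): split into N + 1 keys, keep the first
# as the advanced key, and hand out the rest.
key = jax.random.PRNGKey(0)
key, k1, k2 = jax.random.split(key, 3)
x = jax.random.normal(k1, (4,))
y = jax.random.normal(k2, (4,))
```

One design point the sketch makes visible: `lax.fori_loop` passes the index before the carry (`body(i, carry)`), whereas `fori/l` puts the carry bindings first and the iteration descriptor second, matching Clojure's `loop`.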
The macro-name suffixes:

- `/j`: jit. Currently just `defn/j`.
- `/l`: lax, in its idiomatic form. Always an expression. For `if/l`/`cond/l`, free symbols in the branches are auto-detected and closed over. For `while/l`/`fori/l`, you list the carry names explicitly, and their initial values are pulled from the enclosing scope.
- `/lp`: lax pure. Lower-level explicit forms: for `if/lp`, you list the closed-over bindings yourself; for `while/lp`/`fori/lp`, you list the carry names and their initial values (`[name init]`). Use these when the `/l` sugar is inconvenient, e.g. initializing with a different expression than a same-named outer binding.
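The difference between the `/l` and `/lp` suffixes is easiest to see in plain JAX. The sketch below assumes (this is not taken from hyjax's source) that the auto-detected free symbols of an `/l` form become closures captured by zero-argument branch functions, and that the bindings listed in an `/lp` form become explicit operands. It also writes `cond/l` as nested `lax.cond` calls, `case/l` as `lax.switch`, and a `while/lp`-style carry initialized from an expression other than the outer binding. All function names, and the Hy forms quoted in the comments, are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def relu_if_l(x):
    # /l style, e.g. (if/l (> x 0) x (jnp.zeros_like x)):
    # x is a free symbol, captured by zero-argument branch closures.
    return lax.cond(x > 0, lambda: x, lambda: jnp.zeros_like(x))


@jax.jit
def relu_if_lp(x):
    # /lp style (assumed translation): the listed binding x is passed to
    # lax.cond as an explicit operand, and each branch takes it as a parameter.
    return lax.cond(x > 0, lambda x: x, lambda x: jnp.zeros_like(x), x)


@jax.jit
def sign_cond_l(x):
    # (cond/l (< x 0) -1.0 (> x 0) 1.0 0.0) as a chain of lax.cond:
    # each later test sits in the false branch of the previous one.
    return lax.cond(
        x < 0,
        lambda: -1.0,
        lambda: lax.cond(x > 0, lambda: 1.0, lambda: 0.0),
    )


@jax.jit
def pick(idx, x):
    # (case/l idx (+ x 1) (* x 2) (- x)) as lax.switch over closures.
    # lax.switch clamps idx into the valid branch range.
    return lax.switch(idx, [lambda: x + 1, lambda: x * 2, lambda: -x])


@jax.jit
def count_from_double(v):
    # (while/lp [v (* 2 v)] (< v 10) (+ v 1)): the carry is initialized from
    # (* 2 v) instead of the same-named outer binding v.
    return lax.while_loop(lambda v: v < 10, lambda v: v + 1, 2 * v)
```

Under these assumptions both `relu_if_*` variants compute the same values; the explicit-operand style only makes the data flowing into each branch visible at the call site.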