mlegls/hyjax

Hyjax

uv add hyjax         # or: pip install hyjax

Dev setup:

git clone https://github.com/mlegls/hyjax
cd hyjax
uv sync
uv run pytest                       # run the test suite
uv run hy examples/experiments.hy   # run the example script

Hyjax provides Hy bindings for JAX.

When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has “una anima di pura programmazione funzionale” [a soul of pure functional programming]. –– 🔪 JAX - The Sharp Bits 🔪

“As one would expect from its goals, artificial intelligence research generates many significant programming problems. In other programming cultures this spate of problems spawns new languages. ... We toast the Lisp programmer who pens his thoughts within nests of parentheses.” -- Alan J. Perlis

You now know enough to be dangerous with Hy. You may now smile villainously and sneak off to your Hydeaway to do unspeakable things. –– Hy Tutorial

The goals are:

  1. the name pun was funny
  2. let fully JIT-compiled JAX code be written in idiomatic Lisp

Some examples loosely inspired by the JAX Quickstart tutorials are in examples/experiments.hy. The Hy and JAX docs might also be helpful.

Features

  • (defn/j f [args] body) as @jit def f(args): ...
    • /j for jit; same syntax as Hy's defn, including support for other decorators, annotations, and variadic and keyword args
  • (mapv f vec) as vmap(f)(vec)
  • (if/l pred then else) as lax.cond
    • /l for lax; if-like binary form
  • (cond/l pred1 result1 pred2 result2 ... default) as a chain of lax.cond
    • multi-branch form, mirroring Lisp cond
  • (if/lp [bindings] pred then else): the explicit-binding ("lax pure") form of if/l, in which you list the closed-over symbols yourself instead of relying on auto-detection
  • (while/l [carry...] cond body-expr) as lax.while_loop, in pure-expression style. Carry names are listed explicitly and their initial values are taken from the enclosing scope. The body is a single expression that evaluates to the new carry (a scalar for one name, an N-tuple for N > 1). The loop form itself returns the final carry.
    ; single-state
    (setv v 0)
    (setv v (while/l [v] (< v 10) (+ v 1)))
    
    ; multi-state
    (setv a 0  b 1)
    (setv [a b] (while/l [a b] (< a 100) #(b (+ a b))))
  • (fori/l [carry...] [i (range ...)] body-expr) as lax.fori_loop, in the same style. The bindings-first ordering matches Clojure's loop and Common Lisp's do: the state carry comes first, followed by the iteration descriptor.
    (setv acc 0)
    (setv acc (fori/l [acc] [i (range n)] (+ acc i)))
    
    (setv a 0  b 1)
    (setv [a b] (fori/l [a b] [i (range n)] #(b (+ a b))))
  • (while/lp [pattern init] cond body-expr) and (fori/lp [pattern init] [i lo hi] body-expr): explicit-binding ("lax pure") forms of the loops. Use these when you want full control over what goes into the carry (e.g. to keep loop-local temporaries out of it) or want to initialize it with an expression other than a same-named outer binding.
  • (scan/l [carry...] [x xs] body-expr) as lax.scan. Body evaluates to a #(new-carry y) pair; the form returns #(final-carry stacked-ys). Same carry-from-scope convention as while/l / fori/l. scan/lp is the explicit [pattern init] variant.
    ; cumulative sum
    (setv acc 0)
    (setv [final sums]
          (scan/l [acc] [x xs] #((+ acc x) (+ acc x))))
  • (case/l idx expr0 expr1 ...) as lax.switch. Like cond/l but branches are selected by integer index; free symbols are auto-detected and closed over.
  • (ascan/l f xs) as lax.associative_scan. Extra kwargs are spliced through: (ascan/l + xs :reverse True).
  • (mapl f xs) as lax.map (sequential), complementing mapv which is jax.vmap.
  • (mapt f tree ...) as jax.tree.map. Accepts multiple pytrees for co-map.
    (mapt (fn [p g] (- p (* lr g))) params grads)
  • (with-keys key [k1 k2 k3] body): PRNG key-splitting sugar. It consumes key, rebinds it to split(key, 4)[0], and binds k1, k2, k3 to the remaining slots. The rebinding escapes into the enclosing scope, so subsequent code sees the advanced key (matching how Python JAX code threads keys).
    (setv key (jax.random.PRNGKey 0))
    (with-keys key [k1 k2]
      (setv x (jax.random.normal k1 #(4)))
      (setv y (jax.random.normal k2 #(4))))
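For readers who know JAX from Python, the sketch below shows roughly what defn/j and mapv correspond to, following the equivalences listed above. It is not Hyjax's generated code; the function names scale_and_shift and sum_of_squares and the input values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical Hy source: (defn/j scale-and-shift [x a b] (+ (* a x) b))
# Per the feature list, defn/j corresponds to a jit-decorated def.
@jax.jit
def scale_and_shift(x, a, b):
    return a * x + b

# Hypothetical per-row function used with vmap below.
def sum_of_squares(v):
    return jnp.sum(v ** 2)

batch = jnp.arange(6.0).reshape(3, 2)          # rows [0, 1], [2, 3], [4, 5]

# (mapv sum-of-squares batch) corresponds to vmap(f)(vec):
# f is applied independently along the leading axis.
per_row = jax.vmap(sum_of_squares)(batch)      # [1., 13., 41.]
y = scale_and_shift(jnp.array(2.0), 3.0, 1.0)  # 3 * 2 + 1 = 7.0
```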
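The branching forms map onto lax.cond and lax.switch. Here is a plain-JAX sketch of what if/l, cond/l, and case/l compute; the Hy expressions in the comments and the thresholds are hypothetical examples, and the free symbol x is passed to each branch as an operand.

```python
import jax
import jax.numpy as jnp
from jax import lax

# (if/l (> x 0) (* x 2) (- x)): one lax.cond; both branches are traced,
# and the predicate selects one at run time.
@jax.jit
def if_l(x):
    return lax.cond(x > 0, lambda x: x * 2, lambda x: -x, x)

# (cond/l (< x 0) -1 (< x 10) 0 1): a chain of lax.cond, where each false
# branch holds the rest of the chain and the innermost one is the default.
# Explicit int32 constants keep every branch's output type identical.
@jax.jit
def cond_l(x):
    return lax.cond(
        x < 0,
        lambda x: jnp.int32(-1),
        lambda x: lax.cond(x < 10, lambda x: jnp.int32(0), lambda x: jnp.int32(1), x),
        x,
    )

# (case/l idx (+ x 1) (* x 10) (- x 1)): lax.switch picks a branch by integer index.
@jax.jit
def case_l(idx, x):
    return lax.switch(idx, [lambda x: x + 1, lambda x: x * 10, lambda x: x - 1], x)
```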
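The loop forms correspond to lax.while_loop and lax.fori_loop. The sketch below restates the README's own loop examples in plain JAX to show how the carry travels: a single name is a bare value, several names become a tuple, and fori_loop's body also receives the index. The wrapper function names are hypothetical, and the fori_loop bounds assume (range n) means 0 up to n.

```python
import jax.numpy as jnp
from jax import lax

# (setv v (while/l [v] (< v 10) (+ v 1))): a single-name carry is a bare value.
def count_to_ten(v):
    return lax.while_loop(lambda v: v < 10, lambda v: v + 1, v)

# (setv [a b] (while/l [a b] (< a 100) #(b (+ a b)))): two names become a tuple carry.
def fib_until(limit):
    init = (jnp.int32(0), jnp.int32(1))
    return lax.while_loop(lambda c: c[0] < limit, lambda c: (c[1], c[0] + c[1]), init)

# (setv acc (fori/l [acc] [i (range n)] (+ acc i))): the body takes (i, carry).
def sum_below(n):
    return lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))

# (setv [a b] (fori/l [a b] [i (range n)] #(b (+ a b)))): i is unused here.
def fib_steps(n):
    init = (jnp.int32(0), jnp.int32(1))
    return lax.fori_loop(0, n, lambda i, c: (c[1], c[0] + c[1]), init)
```

For example, fib_until(100) stops at the first a that is not below 100, returning (144, 233), while fib_steps(10) returns (55, 89).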
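The scan forms correspond to lax.scan and lax.associative_scan. This sketch reproduces the README's cumulative-sum example in plain JAX, with a hypothetical input xs and step function step. Where the Hy form passes +, the Python version passes jnp.add.

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.array([1, 2, 3, 4], dtype=jnp.int32)

# (scan/l [acc] [x xs] #((+ acc x) (+ acc x))) with acc starting at 0:
# the step returns (new_carry, y), and scan stacks the ys.
def step(acc, x):
    new_acc = acc + x
    return new_acc, new_acc

final, sums = lax.scan(step, jnp.int32(0), xs)            # final 10, sums [1, 3, 6, 10]

# (ascan/l + xs): the same prefix sums from an associative (parallel) scan.
prefix = lax.associative_scan(jnp.add, xs)                # [1, 3, 6, 10]

# (ascan/l + xs :reverse True): the keyword is passed through, giving suffix sums.
suffix = lax.associative_scan(jnp.add, xs, reverse=True)  # [10, 9, 7, 4]
```

lax.scan runs the step sequentially and threads the carry, whereas associative_scan requires an associative operator but can combine elements in parallel.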
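mapl and mapt map onto lax.map and jax.tree.map. The sketch below shows both in plain JAX, including the README's gradient-step example for mapt; the parameter and gradient values and lr = 0.1 are made-up illustration data.

```python
import jax
import jax.numpy as jnp
from jax import lax

# (mapl f xs) as lax.map: f runs on each leading-axis slice sequentially,
# unlike vmap, which vectorizes across the axis.
squares = lax.map(lambda x: x * x, jnp.arange(4.0))   # [0., 1., 4., 9.]

# (mapt (fn [p g] (- p (* lr g))) params grads) as jax.tree.map over two
# pytrees with matching structure (here, dicts with the same keys).
lr = 0.1
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([10.0, -10.0]), "b": jnp.array(5.0)}
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
# new_params["w"] is approximately [0., 3.] and new_params["b"] approximately 0.
```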
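In plain JAX, the README's with-keys example corresponds roughly to a single jax.random.split followed by two draws. The split count of 3 for two names is an assumption extrapolated from the three-name case described above (split(key, 4)), with slot 0 becoming the advanced key.

```python
import jax

key = jax.random.PRNGKey(0)

# Assumed expansion of (with-keys key [k1 k2] ...): split into 1 + 2 subkeys.
# Slot 0 replaces key, so later code sees the advanced key; the rest are k1, k2.
key, k1, k2 = jax.random.split(key, 3)

x = jax.random.normal(k1, (4,))
y = jax.random.normal(k2, (4,))
```

Because key is rebound rather than reused, a later split yields fresh subkeys instead of repeating k1 and k2.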

Naming convention

  • /j: jit. Currently just defn/j.
  • /l: lax, in its idiomatic form. Always an expression. For if/l, cond/l, and case/l, free symbols in the branches are auto-detected and closed over. For while/l, fori/l, and scan/l, you list the carry names explicitly and their initial values are pulled from the enclosing scope.
  • /lp: lax pure. Lower-level explicit forms: for if/lp, you list the closed-over bindings yourself; for while/lp, fori/lp, and scan/lp, you list the carry patterns and their initial values ([pattern init]). Use these when the /l sugar is inconvenient, e.g. to initialize the carry with a different expression than a same-named outer binding.
