From e6bac9c086352c824c4396431cdd0b3e8c7a9aa1 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 30 May 2026 11:50:43 -0700 Subject: [PATCH] optimize jit_partial PiperOrigin-RevId: 923984648 --- benchmarks/nnx_graph_overhead.py | 33 ++- benchmarks/nnx_simple_training.py | 84 +++++- flax/nnx/extract.py | 72 ++--- flax/nnx/transforms/compilation.py | 428 +++++++++++++++++++++++------ flax/nnx/transforms/iteration.py | 11 +- tests/nnx/transforms_test.py | 11 +- 6 files changed, 493 insertions(+), 146 deletions(-) diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py index fd20fc5a8..92ffc7c4e 100644 --- a/benchmarks/nnx_graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -25,7 +25,7 @@ FLAGS = flags.FLAGS flags.DEFINE_enum( - 'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in' + 'mode', 'nnx', ['all', 'nnx', 'jax', 'jit_partial'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') flags.DEFINE_integer('width', 32, 'Hidden layer size') @@ -91,7 +91,6 @@ def main(argv): model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - t0 = time() @nnx.jit def step_nnx(model: MLP, optimizer: nnx.Optimizer): @@ -108,6 +107,35 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer): print('total time:', total_time) print(f'time per step: {time_per_step * 1e6:.2f} µs') print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + print() + + # ------------------------------------------------------------ + # JIT Partial + # ------------------------------------------------------------ + if mode in ['all', 'jit_partial']: + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + + def step_partial(model: MLP, optimizer: nnx.Optimizer): + pass + + step_partial_jit = nnx.jit_partial( + step_partial, model, optimizer, graph=False + ) + + t0 = time() + for _ in range(total_steps): + step_partial_jit() + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + print('### JIT PARTIAL ###') + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + print() # ------------------------------------------------------------ # JAX @@ -117,7 +145,6 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer): model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) tx = optax.sgd(1e-3) optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - t0 = time() @jax.jit def step_jax(graphdef, state): diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py index e134fa317..2f543b1d8 100644 --- a/benchmarks/nnx_simple_training.py +++ b/benchmarks/nnx_simple_training.py @@ -13,6 +13,9 @@ # limitations under the License. # %% +import cProfile +import pstats +import io from functools import partial import jax import jax.numpy as jnp @@ -27,12 +30,13 @@ FLAGS = flags.FLAGS flags.DEFINE_enum( - 'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in' + 'mode', 'all', ['all', 'nnx', 'jax', 'jit_partial'], 'Mode to run the script in' ) flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') flags.DEFINE_integer('batch_size', 32, 'Batch size') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 5, 'Depth of the model') +flags.DEFINE_bool('profile', False, 'Enable cProfile profiling') def dataset(X, Y, batch_size): @@ -67,13 +71,13 @@ class MLP(nnx.Module): def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): self.count = Count(jnp.array(0)) self.linear_in = Block(din, dhidden, rngs=rngs) - self.intermediates = [ + self.intermediates = nnx.List([ Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) - ] + ]) self.linear_out = Block(dhidden, dout, rngs=rngs) def __call__(self, x): - self.count.value += 1 + self.count[...] += 1 x = nnx.relu(self.linear_in(x)) for layer in self.intermediates: x = nnx.relu(layer(x)) @@ -118,6 +122,7 @@ def test_step_nnx(model: MLP, batch): loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} + logs = {'loss': jnp.array(0.0)} for step, batch in enumerate(dataset(X, Y, batch_size)): train_step_nnx(model, optimizer, batch) @@ -132,7 +137,73 @@ def test_step_nnx(model: MLP, batch): total_time = time() - t0 print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + print('times called:', model.count[...]) + print() + + if mode == 'jit_partial' or mode == 'all': + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + t0 = time() + + def train_step(model: MLP, optimizer: nnx.Optimizer, batch): + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = nnx.grad(loss_fn)(model) + optimizer.update(model, grads) + + def test_step(model: MLP, batch): + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + train_step_fn = nnx.jit_partial( + train_step, model, optimizer, graph=False + ) + test_step_fn = nnx.jit_partial(test_step, model, graph=False) + + logs = {'loss': jnp.array(0.0)} + # Warmup + for step, batch in enumerate(dataset(X, Y, batch_size)): + train_step_fn(batch) + if step >= 10: + break + + pr = None + if FLAGS.profile: + pr = cProfile.Profile() + pr.enable() + + for step, batch in enumerate(dataset(X, Y, batch_size)): + train_step_fn(batch) + + if step % 1000 == 0: + logs = test_step_fn((X, Y)) + + if step >= total_steps - 1: + break + + if pr is not None: + pr.disable() + for sort_key in ('cumulative', 'tottime'): + s = io.StringIO() + ps = pstats.Stats(pr, stream=s) + ps.sort_stats(sort_key) + ps.print_stats(40) + print(s.getvalue()) + + print('### JIT PARTIAL ###') + print(f'final loss: {logs["loss"]}') + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count[...]) + print() if mode == 'jax' or mode == 'all': model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) @@ -165,6 +236,7 @@ def test_step_jax(state, batch): graphdef, state = nnx.split((model, optimizer)) + logs = {'loss': jnp.array(0.0)} for step, batch in enumerate(dataset(X, Y, batch_size)): state = train_step_jax(state, batch) @@ -181,7 +253,7 @@ def test_step_jax(state, batch): total_time = time() - t0 print('total time:', total_time) print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') - print('times called:', model.count.value) + print('times called:', model.count[...]) if __name__ == '__main__': diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index cc816075b..ae8b71ae8 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -732,7 +732,7 @@ def check_prefix( graph_updates: bool, none_leaf: bool = True, ): - unique_prefixes: OrderedDict[tp.Any, tp.Any] = OrderedDict() + unique_prefixes: set[tp.Any] = set() def _check_prefix(path, leaf): if isinstance(leaf, variablelib.Variable): @@ -798,12 +798,12 @@ def _check_prefix(path, leaf): ) def _collect_prefix(_, leaf): - unique_prefixes[leaf] = leaf + unique_prefixes.add(leaf) jax.tree.map_with_path( _collect_prefix, prefix, is_leaf=lambda x: x is None and none_leaf ) - return unique_prefixes + return list(unique_prefixes) def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> bool: @@ -818,27 +818,28 @@ def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> b ] +@dataclasses.dataclass(slots=True) class Updates( tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]], reprlib.Representable, ): - __slots__ = ('_keys', '_values') + _keys: list[tp.Any] = dataclasses.field(default_factory=list) + _values: list[variablelib.Variable] = dataclasses.field(default_factory=list) - _keys: list[jax.tree_util.KeyPath] - _values: list[variablelib.Variable] - - def __init__( - self, + @classmethod + def create( + cls, items: tp.Iterable[ tuple[jax.tree_util.KeyPath, variablelib.Variable] ] = (), - ): - self._keys, self._values = [], [] + ) -> 'Updates': + keys, values = [], [] for key, value in items: - self._keys.append(key) - self._values.append(value) + keys.append(key) + values.append(value) + return cls(_keys=keys, _values=values) - def append(self, key: jax.tree_util.KeyPath, value: variablelib.Variable): + def append(self, key: tp.Any, value: variablelib.Variable): self._keys.append(key) self._values.append(value) @@ -880,7 +881,7 @@ def __len__(self): return len(self._keys) def __iter__(self): - return iter(zip(self._keys, self._values)) + return zip(self._keys, self._values) def __nnx_repr__(self): yield reprlib.Object(type=type(self), kv_sep=': ', start='({', end='})') @@ -892,30 +893,10 @@ def __nnx_repr__(self): ) -def _updates_flatten_with_keys(x: Updates): - key_children = [ - (jax.tree_util.FlattenedIndexKey(i), v) - for i, v in enumerate(x._values) - ] - return key_children, x._keys - - -def _updates_flatten(x: Updates): - return x._values, x._keys - - -def _updates_unflatten(keys, values) -> Updates: - updates = object.__new__(Updates) - updates._keys = keys - updates._values = list(values) - return updates - - -jax.tree_util.register_pytree_with_keys( +jax.tree_util.register_dataclass( Updates, - _updates_flatten_with_keys, - _updates_unflatten, - flatten_func=_updates_flatten, + data_fields=['_values'], + meta_fields=['_keys'], ) def get_updates( @@ -929,7 +910,7 @@ def get_updates( if keep_fn is None: keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap) - updates = OrderedDict((pfx, Updates()) for pfx in known_prefixes) + updates = {pfx: Updates.create() for pfx in known_prefixes} def _mask_updates(path, prefix_leaf, current, snapshot): if isinstance(current, variablelib.Variable): @@ -944,14 +925,14 @@ def _mask_updates(path, prefix_leaf, current, snapshot): _mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf, prefix_leaf=prefix_leaf, ) - return updates + return list(updates.values()) def apply_updates( variables: dict[jax.tree_util.KeyPath, variablelib.Variable], - updates: OrderedDict[tp.Any, Updates], + updates: list[Updates], ): - for _, flat_state in updates.items(): + for flat_state in updates: for path, update in flat_state: if path in variables: variable = variables[path] @@ -965,6 +946,7 @@ def apply_updates( ) + def treemap_copy_args(f: F) -> F: @functools.wraps(f) def wrapper(*args, **kwargs): @@ -1110,9 +1092,9 @@ def _apply_prefix(jax_path, leaf): return jax.tree.map_with_path(_apply_prefix, node, is_leaf=is_leaf) -def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]): - combined: OrderedDict[tp.Any, tp.Any] = OrderedDict() - for updates in all_updates.values(): +def to_masked(tree, all_updates: list[Updates]): + combined: dict[tp.Any, tp.Any] = {} + for updates in all_updates: combined.update(updates) return jax.tree.map_with_path( lambda path, _: combined.get(path, None), tree, diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index 52ec0205a..5f01f382e 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -30,6 +30,8 @@ statelib, variablelib, ) +from flax import errors +from flax.nnx import tracers from flax.nnx.extract import labeled from flax.nnx.transforms.transforms import ( _resolve_bound_callable, @@ -383,7 +385,8 @@ def jit( update_shardings = extract.check_prefix( in_shardings, 'in_shardings', 'jit', graph, graph_updates ) - update_shardings[None] = None # kwargs sharding + if None not in update_shardings: + update_shardings.append(None) # kwargs sharding extract.check_prefix( out_shardings, 'out_shardings', 'jit', graph, graph_updates ) @@ -413,39 +416,48 @@ def jit( ) -@dataclasses.dataclass(frozen=True, slots=True) -class PartialState: - """Container for a pre-flattened partial argument. - - Stores the pytree structure (``treedef``) as static metadata and the - flattened leaves as dynamic data. Variables within the original argument - are kept as leaves so their values can change between calls without - triggering recompilation. - """ - treedef: jax.tree_util.PyTreeDef - leaves: list[tp.Any] -jax.tree_util.register_dataclass( - PartialState, - data_fields=['leaves'], - meta_fields=['treedef'], -) +@dataclasses.dataclass(eq=False) +class SimpleJitFn: + f: tp.Callable[..., tp.Any] + in_shardings: tp.Any + out_shardings: tp.Any + donate_argnums: frozenset[int] + donate_argnames: frozenset[str] + graph: bool + update_shardings: tuple[tp.Any, ...] + def __post_init__(self): + functools.update_wrapper(self, self.f, updated=()) -def _flatten_to_partial_state( - arg: tp.Any, - ref_index: graphlib.RefMap | None, -) -> PartialState: - if ref_index is not None: - graphdef, flat_state = graphlib.flatten(arg, ref_index=ref_index, graph=True) - return PartialState(treedef=graphdef, leaves=flat_state.leaves) - is_leaf = lambda x: isinstance(x, variablelib.Variable) - leaves, treedef = jax.tree.flatten(arg, is_leaf=is_leaf) - return PartialState(treedef=treedef, leaves=leaves) + @extract.treemap_copy_args + def __call__(self, *args, **kwargs): + current, snapshot = extract.snapshot( + labeled(args=args, kwargs=kwargs) + ) + if self.graph: + args, kwargs = extract.from_tree2((args, kwargs)) + out = self.f(*args, **kwargs) + if self.graph: + out = extract.to_tree2(out, prefix=self.out_shardings) + extract.check_no_aliases('jit', **current, out=out, check=['out']) + def keep_fn(jax_path, prefix, c, s): + if extract.variable_changed(c, s): + return True + arg_type, arg_key, *_ = graphlib.jax_to_nnx_path(jax_path) + if arg_type == 'args': + return arg_key in self.donate_argnums + else: # arg_type == 'kwargs': + return arg_key in self.donate_argnames + updates = extract.get_updates( + current, snapshot, prefix=labeled(args=self.in_shardings, kwargs=None), + known_prefixes=self.update_shardings, keep_fn=keep_fn + ) + return out, updates @dataclasses.dataclass(eq=False) -class SimpleJitFn: +class SimpleJitPartialFn: f: tp.Callable[..., tp.Any] in_shardings: tp.Any out_shardings: tp.Any @@ -468,8 +480,14 @@ def __call__(self, *args, **kwargs): if self.graph: out = extract.to_tree2(out, prefix=self.out_shardings) extract.check_no_aliases('jit', **current, out=out, check=['out']) - def keep_fn(jax_path, prefix, c, s): + def keep_fn(jax_path, prefix, c: variablelib.Variable, s: variablelib.Variable): if extract.variable_changed(c, s): + if c.get_metadata() != s.get_metadata(): + path_str = jax.tree_util.keystr(jax_path) + raise ValueError( + f'Variable metadata changed inside jit at path {path_str}. ' + f'Changing Variable metadata inside jit is not supported.' + ) return True arg_type, arg_key, *_ = graphlib.jax_to_nnx_path(jax_path) if arg_type == 'args': @@ -480,6 +498,13 @@ def keep_fn(jax_path, prefix, c, s): current, snapshot, prefix=labeled(args=self.in_shardings, kwargs=None), known_prefixes=self.update_shardings, keep_fn=keep_fn ) + for update_group in updates: + update_group._keys = [ + k[-1].idx for k in update_group._keys + ] + update_group._values = [ + v.get_raw_value() for v in update_group._values + ] return out, updates @@ -498,9 +523,9 @@ def __init__( device: tp.Optional[jax.Device], backend: tp.Optional[str], inline: bool, - partial_args: tuple[PartialState, ...], + partial_args: tuple[tp.Any, ...], graph: bool, - update_shardings: extract.OrderedDict, + update_shardings: tuple[tp.Any, ...], ): functools.update_wrapper(self, fun) self.fun: tp.Callable[P, R] = fun @@ -549,12 +574,17 @@ def __init__( def _maybe_to_tree(self, args, kwargs): if self.graph: + if self.in_shardings is not None and isinstance(self.in_shardings, (tuple, list)): + runtime_prefix = self.in_shardings[len(self.partial_args):] + else: + runtime_prefix = self.in_shardings + args, kwargs = extract.to_tree2( (args, kwargs), - prefix=(self.in_shardings, None) - if self.in_shardings is not None + prefix=(runtime_prefix, None) + if runtime_prefix is not None else None, - check_aliasing=self.in_shardings is not None, + check_aliasing=runtime_prefix is not None, ) return args, kwargs @@ -581,7 +611,7 @@ def eval_shape(self, *args, **kwargs): args, kwargs = self._maybe_to_tree(args, kwargs) if not self.graph: extract.check_no_aliases('jit', args=args, kwargs=kwargs) - out, updates = self.jitted_fn.eval_shape(*args, **kwargs) + out, _ = self.jitted_fn.eval_shape(*args, **kwargs) return self._maybe_from_tree(out) def trace(self, *args, **kwargs): @@ -599,6 +629,109 @@ def lower(self, *args, **kwargs): extract.check_no_aliases('jit', args=args, kwargs=kwargs) lowered = self.jitted_fn.lower(*args, **kwargs) return SimpleLowered(lowered, self) + + +def _apply_raw_updates( + partial_args: list[tp.Any], + updates: list[extract.Updates], +): + """Apply updates containing raw values using integer indices into partial_args.""" + trace = tracers.current_jax_trace() + for flat_state in updates: + for index, raw_value in flat_state: + var = partial_args[index] + if var._trace_state._jax_trace != trace: + raise errors.TraceContextError( + f'Cannot mutate {type(var).__name__} from a different trace level' + ) + object.__setattr__(var, '_raw_value', raw_value) + + +class SimpleJitPartialWrapped(tp.Generic[P, R]): + + def __init__( + self, + fun: tp.Callable[P, R], + in_shardings: tp.Any, + out_shardings: tp.Any, + static_argnums: int | tp.Sequence[int] | None, + static_argnames: str | tp.Iterable[str] | None, + donate_argnums: int | tp.Sequence[int] | None, + donate_argnames: str | tp.Iterable[str] | None, + keep_unused: bool, + device: tp.Optional[jax.Device], + backend: tp.Optional[str], + inline: bool, + partial_args: list[tp.Any], + graph: bool, + update_shardings: list[tp.Any], + ): + functools.update_wrapper(self, fun) + self.fun: tp.Callable[P, R] = fun + self.in_shardings = in_shardings + self.out_shardings = out_shardings + self.partial_args = partial_args + self.graph = graph + + donate_argnums_set = frozenset( + (donate_argnums,) if isinstance(donate_argnums, int) + else donate_argnums or () + ) + donate_argnames_set = frozenset( + (donate_argnames,) if isinstance(donate_argnames, str) + else donate_argnames or () + ) + self.jitted_fn = jax.jit( + SimpleJitPartialFn( + fun, + in_shardings, + out_shardings, + donate_argnums_set, + donate_argnames_set, + graph, + tuple(update_shardings), + ), + in_shardings=in_shardings, + out_shardings=(out_shardings, update_shardings), + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + ) + + def _maybe_from_tree(self, out): + if self.graph: + out = extract.from_tree2(out) + return out + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + out, updates = self.jitted_fn(self.partial_args, *args, **kwargs) + _apply_raw_updates(self.partial_args, updates) + return self._maybe_from_tree(out) + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return functools.partial(self, obj) + + def eval_shape(self, *args, **kwargs): + out, _ = self.jitted_fn.eval_shape( + self.partial_args, *args, **kwargs + ) + return self._maybe_from_tree(out) + + def trace(self, *args, **kwargs): + traced = self.jitted_fn.trace(self.partial_args, *args, **kwargs) + return SimplePartialTraced(traced, self) + + def lower(self, *args, **kwargs): + lowered = self.jitted_fn.lower(self.partial_args, *args, **kwargs) + return SimplePartialLowered(lowered, self) + def jit_partial( fun: tp.Callable[..., R], *partial_args: tp.Any, @@ -614,12 +747,12 @@ def jit_partial( inline: bool = False, graph: bool | None = None, graph_updates: bool | None = None, -) -> SimpleJitWrapped[..., R]: +) -> SimpleJitPartialWrapped[..., R]: """JIT-compile ``fun`` with pre-flattened partial arguments. Similar to ``nnx.cached_partial`` but designed for tree-mode (``graph=False``). Each ``partial_arg`` is flattened into a - ``PartialState`` whose pytree structure is fixed at construction time. + list of Variables and Arrays whose pytree structure is fixed at construction time. Variable values inside partial arguments can still change between calls without triggering recompilation, and any mutations to Variables are propagated back to the originals after each call. @@ -684,7 +817,8 @@ def jit_partial( update_shardings = extract.check_prefix( in_shardings, 'in_shardings', 'jit_partial', graph, graph_updates ) - update_shardings[None] = None # kwargs sharding + if None not in update_shardings: + update_shardings.append(None) # kwargs sharding if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)): raise ValueError( '`in_shardings` cannot contain `StateSharding` objects ' @@ -697,65 +831,89 @@ def jit_partial( ) is_variable = lambda x: isinstance(x, variablelib.Variable) - ref_index = graphlib.RefMap() if graph else None - flat_partial_args = tuple( - _flatten_to_partial_state(arg, ref_index=ref_index) - for arg in partial_args + + # 1. Graph->tree conversion and alias check beforehand + if graph: + if in_shardings is not None and isinstance(in_shardings, (tuple, list)): + partial_in_axes = in_shardings[:len(partial_args)] + else: + partial_in_axes = in_shardings + tree_partial_args = extract.to_tree2( + partial_args, + prefix=partial_in_axes, + check_aliasing=partial_in_axes is not None, + ) + else: + tree_partial_args = partial_args + + # Check no aliases beforehand + extract.check_no_aliases('jit_partial', args=tree_partial_args) + + # 2. Flatten the partial_args to a single list of Variables and Arrays + flat_partial_args, partial_treedef = jax.tree.flatten( + tree_partial_args, is_leaf=is_variable ) + + # 4. Sharding calculation + # partial_args is passed as a single list argument, so in_shardings + # for that argument is a list matching its pytree structure. jit_in_shardings: tp.Any = None - if in_shardings is not None and isinstance(in_shardings, (tuple, list)) and not graph: + if in_shardings is not None and isinstance(in_shardings, (tuple, list)): num_partial = len(partial_args) partial_shardings = in_shardings[:num_partial] runtime_shardings = in_shardings[num_partial:] - flat_partial_shardings = [] - for flat_arg, orig_arg, sharding in zip( - flat_partial_args, partial_args, partial_shardings): - broadcasted = extract.broadcast_prefix( - sharding, orig_arg, + broadcasted = extract.broadcast_prefix( + partial_shardings, tree_partial_args, prefix_is_leaf=lambda x: x is None or isinstance(x, variablelib.Variable), tree_is_leaf=is_variable, - ) - flat_partial_shardings.append( - PartialState(treedef=flat_arg.treedef, leaves=broadcasted) - ) - jit_in_shardings = (*flat_partial_shardings, *runtime_shardings) + ) + flat_partial_shardings = jax.tree.leaves(broadcasted, is_leaf=lambda x: x is None) + jit_in_shardings = (flat_partial_shardings, *runtime_shardings) else: jit_in_shardings = in_shardings + # 5. wrapped_fun accepts partial_args as its first argument (a list) @functools.wraps(fun) - def wrapped_fun(*args, **kwargs): - index_ref = graphlib.IndexMap() if graph else None - def _unflatten(arg): - if not isinstance(arg, PartialState): - return arg - elif graph: - return graphlib.unflatten( - arg.treedef, arg.leaves, index_ref=index_ref, - copy_variables=False, - ) - else: - return jax.tree.unflatten(arg.treedef, arg.leaves) - args = (_unflatten(a) for a in args) - return fun(*args, **kwargs) + def wrapped_fun(flat_partial_list, *args, **kwargs): + # Check no Variables in runtime args/kwargs + runtime_leaves = jax.tree.leaves((args, kwargs), is_leaf=is_variable) + if any(is_variable(x) for x in runtime_leaves): + raise ValueError( + 'Found Variable in non-partial arguments. ' + 'jit_partial only supports Variables in partial arguments.' + ) - return SimpleJitWrapped( - wrapped_fun, - in_shardings=jit_in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - donate_argnames=donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - partial_args=flat_partial_args, - graph=graph, - update_shardings=update_shardings, + # Unflatten to tree_partial_args (which contains TreeState if graph=True) + tree_partial_args = jax.tree.unflatten(partial_treedef, flat_partial_list) + + # Convert TreeState back to Modules if graph=True, preserving Variable identity + if graph: + reconstructed_partial_args = extract.from_tree2( + tree_partial_args, recreate_variables=False + ) + else: + reconstructed_partial_args = tree_partial_args + + return fun(*reconstructed_partial_args, *args, **kwargs) + + return SimpleJitPartialWrapped( + wrapped_fun, + in_shardings=jit_in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + partial_args=flat_partial_args, + graph=graph, + update_shardings=update_shardings, ) @@ -1265,6 +1423,112 @@ def lower( ) -> SimpleLowered: lowered = self.traced.lower(lowering_platforms=lowering_platforms) return SimpleLowered(lowered, self.jit_wrapped) + + +@dataclasses.dataclass(frozen=True, slots=True) +class SimplePartialCompiled(Stage): + compiled: jax.stages.Compiled + jit_wrapped: SimpleJitPartialWrapped + + @property + def _inner_obj(self): + return self.compiled + + @property + def args_info(self) -> tp.Any: + raise self.compiled.args_info + + @staticmethod + def call(*args, **kwargs): + raise NotImplementedError + + def __call__(self, *args, **kwargs): + out, updates = self.compiled(self.jit_wrapped.partial_args, *args, **kwargs) + _apply_raw_updates(self.jit_wrapped.partial_args, updates) + return self.jit_wrapped._maybe_from_tree(out) + + @property + def out_tree(self) -> jax.tree_util.PyTreeDef: + return self.compiled.out_tree + + def as_text(self) -> str | None: + return self.compiled.as_text() + + def cost_analysis(self) -> tp.Any | None: + return self.compiled.cost_analysis() + + def memory_analysis(self) -> tp.Any | None: + return self.compiled.memory_analysis() + + def runtime_executable(self) -> tp.Any | None: + return self.compiled.runtime_executable() + + @property + def input_shardings(self): + return self.compiled.input_shardings + + @property + def output_shardings(self): + return self.compiled.output_shardings + + @property + def input_layouts(self): + return self.compiled.input_formats + + +@dataclasses.dataclass(frozen=True, slots=True) +class SimplePartialLowered(Stage): + lowered: jax.stages.Lowered + jit_wrapped: SimpleJitPartialWrapped + + @property + def _inner_obj(self): + return self.lowered + + @property + def args_info(self) -> tp.Any: + return self.lowered.args_info + + @property + def out_tree(self): + return self.lowered.out_tree + + def compile( + self, compiler_options: jax.stages.CompilerOptions | None = None + ) -> SimplePartialCompiled: + compiled = self.lowered.compile(compiler_options) + return SimplePartialCompiled(compiled, self.jit_wrapped) + + def as_text( + self, dialect: str | None = None, *, debug_info: bool = False + ) -> str: + return self.lowered.as_text(dialect=dialect, debug_info=debug_info) + + def compiler_ir(self, dialect: str | None = None) -> tp.Any | None: + return self.lowered.compiler_ir(dialect=dialect) + + def cost_analysis(self) -> tp.Any | None: + return self.lowered.cost_analysis() + + +@dataclasses.dataclass(frozen=True, slots=True) +class SimplePartialTraced(Stage): + traced: jax.stages.Traced + jit_wrapped: SimpleJitPartialWrapped + + @property + def _inner_obj(self): + return self.traced + + @property + def out_info(self): + return self.traced.out_info + + def lower( + self, *, lowering_platforms: tuple[str, ...] | None = None + ) -> SimplePartialLowered: + lowered = self.traced.lower(lowering_platforms=lowering_platforms) + return SimplePartialLowered(lowered, self.jit_wrapped) # ------------------------------- # shard_map # ------------------------------- diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index a13516600..791265a11 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -517,7 +517,8 @@ def vmap( _raise_bound_method_error('vmap') update_axes = extract.check_prefix(in_axes, 'in_axes', 'vmap', graph, graph_updates) - update_axes[0] = 0 # kwargs axes + if 0 not in update_axes: + update_axes.append(0) # kwargs axes extract.check_prefix(out_axes, 'out_axes', 'vmap', graph, graph_updates) if not (graph and graph_updates): @@ -771,7 +772,8 @@ def pmap( _raise_bound_method_error('pmap') update_axes = extract.check_prefix(in_axes, 'in_axes', 'pmap', graph, graph_updates) - update_axes[0] = 0 # kwargs axes + if 0 not in update_axes: + update_axes.append(0) # kwargs axes extract.check_prefix(out_axes, 'out_axes', 'pmap', graph, graph_updates) if not (graph and graph_updates): @@ -1662,12 +1664,11 @@ def _simple_scan( f, f_unbound, *, graph, in_axes, out_axes, length, reverse, unroll, _split_transpose, - updates_axes: extract.OrderedDict, + updates_axes: list[tp.Any], ): _validate_scan_axes(in_axes, out_axes) # None and Carry aren't valid update axes - updates_axes.pop(None, None) - updates_axes.pop(Carry, None) + updates_axes = [ax for ax in updates_axes if ax is not None and ax is not Carry] out_is_tuple = isinstance(out_axes, tuple) was_carry = in_axes is Carry diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index a1795caf0..b1d75bc73 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -1087,7 +1087,7 @@ def test_jit_partial_no_partial_args(self, graph_mode): y = f_partial(jnp.array(3.0)) np.testing.assert_allclose(y, 6.0) - @parameterized.parameters((False,)) + @parameterized.parameters(True, False) def test_jit_partial_in_shardings_none_broadcast(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) @@ -1103,7 +1103,7 @@ def f(m, x): y = f_jit(x) self.assertEqual(y.shape, (n_devices, 3)) - @parameterized.parameters((False,)) + @parameterized.parameters(True, False) def test_jit_partial_in_shardings_named(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) @@ -1124,7 +1124,7 @@ def f(v, x): y = f_jit(x) self.assertEqual(y.shape, (n_devices, 4)) - @parameterized.parameters((False,)) + @parameterized.parameters(True, False) def test_jit_partial_mixed_shardings(self, graph_mode): n_devices = jax.local_device_count() devices = mesh_utils.create_device_mesh((n_devices,)) @@ -1196,12 +1196,13 @@ def f(c1, c2, x): c1.v[...] += x return c1.v[...] + c2.v[...] - f_jit = nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) if not graph: with self.assertRaisesRegex(ValueError, 'Duplicate Param'): - f_jit(jnp.array(1.0)) + nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) return + f_jit = nnx.jit_partial(f, c1, c2, graph=graph, graph_updates=False) + y = f_jit(jnp.array(1.0)) np.testing.assert_allclose(y, 4.0) np.testing.assert_allclose(v[...], 2.0)