Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dataclasses
import functools
import inspect
import types
import typing as tp

from flax import struct
Expand Down Expand Up @@ -676,11 +677,25 @@ def copy_var_structure(tree: A) -> A:
)

def check_no_aliases(
fn_name: str, /, *, check: tp.Iterable[str] = (), **kwargs
fn_name: str,
/,
*,
check: tp.Iterable[str] = (),
capture_ids: frozenset[int] = frozenset(),
**kwargs,
) -> dict[jax.tree_util.KeyPath, variablelib.Variable]:
"""Check that no Variable appears more than once in the given collections.

When ``capture_ids`` is provided (ids of Variables belonging to closure
captures), a Variable that was first seen as a capture is allowed to appear
exactly once more as an explicit argument (deduplication case). A Variable
seen twice in non-capture positions, or three or more times total, always
raises.
"""
container = labeled(**kwargs)
is_leaf = lambda x: isinstance(x, variablelib.Variable)
seen: dict[int, jax.tree_util.KeyPath] = {}
dedup_used: set[int] = set() # capture_ids that have matched one explicit arg
all_variables: dict[jax.tree_util.KeyPath, variablelib.Variable] = {}
for path, leaf in jax.tree.leaves_with_path(container, is_leaf=is_leaf):
if not isinstance(leaf, variablelib.Variable):
Expand All @@ -700,6 +715,12 @@ def check_no_aliases(

var_id = id(leaf)
if var_id in seen:
if var_id in capture_ids and var_id not in dedup_used:
# First duplicate: capture prepended + same Variable passed explicitly.
# Allow this once (deduplication); record so a third copy still errors.
dedup_used.add(var_id)
all_variables[path] = leaf
continue
path_str = jax.tree_util.keystr(path)
seen_path_str = jax.tree_util.keystr(seen[var_id])
raise ValueError(
Expand Down Expand Up @@ -1119,6 +1140,97 @@ def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]):
is_leaf=lambda x: x is None
)

@dataclasses.dataclass(frozen=True)
class CapturedInfo:
"""Info about closure-captured graph nodes/Variables."""
nodes: tuple[object, ...]
freevar_indices: tuple[int, ...]

def __len__(self):
return len(self.nodes)

def __bool__(self):
return len(self.nodes) > 0

@property
def variable_ids(self) -> frozenset[int]:
"""Ids of all Variables contained within the captured nodes."""
ids: set[int] = set()
for _, value in graphlib.iter_graph(self.nodes, graph=True):
if isinstance(value, variablelib.Variable):
ids.add(id(value))
return frozenset(ids)

def exclude_from_args(
self,
args: tuple,
kwargs: dict | tp.Mapping,
) -> 'CapturedInfo':
"""Remove captured nodes that also appear in explicit args (by id)."""
if not self.nodes:
return self
arg_ids: set[int] = set()
for _, value in graphlib.iter_graph((args, kwargs), graph=True):
if isinstance(value, variablelib.Variable):
arg_ids.add(id(value))
keep = [(node, idx) for node, idx in zip(self.nodes, self.freevar_indices)
if id(node) not in arg_ids]
if len(keep) == len(self.nodes):
return self
if not keep:
return CapturedInfo((), ())
nodes, indices = zip(*keep)
return CapturedInfo(tuple(nodes), tuple(indices))


_EMPTY_CAPTURED = CapturedInfo((), ())


def find_captured_nodes(f) -> CapturedInfo:
  """Collect the nnx Variables and graph nodes held directly in f's closure.

  Only the objects stored in the closure cells themselves are considered.
  Each distinct object (by ``id``) is reported once, paired with the index of
  the first closure cell that holds it.

  Args:
    f: the callable to inspect.

  Returns:
    A ``CapturedInfo`` with the captured objects and their cell indices, or
    the shared empty instance when ``f`` has no closure.
  """
  cells = getattr(f, '__closure__', None)
  if cells is None:
    return _EMPTY_CAPTURED

  visited_ids: set[int] = set()
  found: list[tuple[object, int]] = []
  for position, cell in enumerate(cells):
    try:
      contents = cell.cell_contents
    except ValueError:
      # Reading an empty cell raises ValueError; such cells are skipped.
      continue
    capturable = (
      graphlib.is_graph_node(contents)
      or isinstance(contents, variablelib.Variable)
    )
    if not capturable or id(contents) in visited_ids:
      continue
    visited_ids.add(id(contents))
    found.append((contents, position))

  return CapturedInfo(
    tuple(node for node, _ in found),
    tuple(position for _, position in found),
  )


def _make_cell(val):
"""Create a closure cell containing val."""
x = val
return (lambda: x).__closure__[0] # type: ignore


def replace_closure_cells(
  f: tp.Callable,
  captured_info: 'CapturedInfo',
  replacements: tuple,
) -> tp.Callable:
  """Return a copy of f with captured closure cells replaced by replacements.

  Args:
    f: the function (or bound method) whose closure is rewritten.
    captured_info: provides ``freevar_indices``, the closure cell indices to
      overwrite; when falsy, ``f`` is returned unchanged.
    replacements: new cell values, paired positionally with
      ``captured_info.freevar_indices``.

  Returns:
    ``f`` itself when there is nothing to replace, otherwise a new function
    sharing f's code and globals but with fresh cells at the captured
    indices. The original function and its cells are not modified.
  """
  closure = getattr(f, '__closure__', None)
  if not captured_info or not closure:
    return f

  # A bound method forwards ``__closure__`` / ``__code__`` to its underlying
  # function; rebuild that function and re-bind it below so ``self`` is kept.
  func = f.__func__ if isinstance(f, types.MethodType) else f

  cells = list(closure)
  for idx, new_val in zip(captured_info.freevar_indices, replacements):
    cells[idx] = types.CellType(new_val)

  new_func = types.FunctionType(
    func.__code__, func.__globals__, func.__name__,
    func.__defaults__, tuple(cells),
  )
  # The FunctionType constructor has no parameter for keyword-only defaults;
  # without this copy, calling the result fails for any f that relies on them.
  if func.__kwdefaults__ is not None:
    new_func.__kwdefaults__ = dict(func.__kwdefaults__)
  # Carry over remaining metadata and any custom function attributes.
  new_func.__qualname__ = func.__qualname__
  new_func.__module__ = func.__module__
  new_func.__doc__ = func.__doc__
  new_func.__dict__.update(func.__dict__)

  if func is not f:
    return types.MethodType(new_func, f.__self__)
  return new_func


def filter_kwargs(f, **kwargs):
sig = inspect.signature(f)
has_var_keyword = any(
Expand Down
40 changes: 31 additions & 9 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class SimpleGradFn:
f: tp.Callable[..., tp.Any]
has_aux: bool
graph: bool
captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED

def __post_init__(self):
functools.update_wrapper(self, self.f, updated=())
Expand All @@ -81,7 +82,14 @@ def __call__(self, *args, **kwargs):
current, snapshot = extract.snapshot(labeled(args=args, kwargs=kwargs))
if self.graph:
args, kwargs = extract.from_tree2((args, kwargs))
out = self.f(*args, **kwargs)
n_captured = len(self.captured_info)
user_args = args[n_captured:]
if self.captured_info:
f = extract.replace_closure_cells(
self.f, self.captured_info, args[:n_captured])
else:
f = self.f
out = f(*user_args, **kwargs)
if self.graph:
out = extract.to_tree2(out)
extract.check_no_aliases('grad', **current, out=out, check=['out'])
Expand Down Expand Up @@ -151,27 +159,41 @@ def _grad_general(

if not graph or not graph_updates:

captured_info = extract.find_captured_nodes(f)

# argnums index into all_args = (*captured_nodes, *explicit_args).
# Captured nodes occupy positions 0..n_cap-1; explicit args follow.
# No shift is applied: argnums=0 refers to the first captured node (e.g. the
# model), which is the common case for closure-based training functions.
gradded_fn = transform(
SimpleGradFn(f, has_aux, graph=graph),
SimpleGradFn(f, has_aux, graph=graph, captured_info=captured_info),
argnums=argnums, # type: ignore[arg-type]
has_aux=True,
holomorphic=holomorphic,
allow_int=allow_int,
)

def tree_grad_wrapper(*args, **kwargs):
all_args = (*captured_info.nodes, *args)
if graph:
diff_argnums = (argnums,) if isinstance(argnums, int) else argnums
args_prefix = tuple(
i in diff_argnums for i in range(len(args))
n_cap = len(captured_info)
diff_argnum_ints = frozenset(
a.argnum if isinstance(a, DiffState) else a
for a in ((argnums,) if isinstance(argnums, (int, DiffState)) else argnums)
)
args, kwargs = extract.to_tree2(
(args, kwargs), prefix=(args_prefix, False),
args_prefix = tuple(j in diff_argnum_ints for j in range(n_cap)) + tuple(
(n_cap + i) in diff_argnum_ints for i in range(len(args))
)
all_args, kwargs = extract.to_tree2(
(all_args, kwargs), prefix=(args_prefix, False),
)

variables = extract.check_no_aliases('grad', args=args, kwargs=kwargs)
variables = extract.check_no_aliases(
'grad', args=all_args, kwargs=kwargs,
capture_ids=captured_info.variable_ids,
)

fn_out = gradded_fn(*args, **kwargs)
fn_out = gradded_fn(*all_args, **kwargs)

if return_value:
if has_aux:
Expand Down
65 changes: 50 additions & 15 deletions flax/nnx/transforms/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class SimpleJitFn:
donate_argnames: frozenset[str]
graph: bool
update_shardings: tuple[tp.Any, ...]
captured_info: extract.CapturedInfo = extract._EMPTY_CAPTURED

def __post_init__(self):
functools.update_wrapper(self, self.f, updated=())
Expand All @@ -464,7 +465,14 @@ def __call__(self, *args, **kwargs):
)
if self.graph:
args, kwargs = extract.from_tree2((args, kwargs))
out = self.f(*args, **kwargs)
n_captured = len(self.captured_info)
user_args = args[n_captured:]
if self.captured_info:
f = extract.replace_closure_cells(
self.f, self.captured_info, args[:n_captured])
else:
f = self.f
out = f(*user_args, **kwargs)
if self.graph:
out = extract.to_tree2(out, prefix=self.out_shardings)
extract.check_no_aliases('jit', **current, out=out, check=['out'])
Expand Down Expand Up @@ -508,6 +516,9 @@ def __init__(
self.partial_args = partial_args
self.graph = graph

# Capture closure nodes once at construction time
self._captured_info = extract.find_captured_nodes(fun)

resolved = _resolve_argnums(fun, static_argnums, static_argnames)
if isinstance(in_shardings, (tuple, list)) and resolved:
expanded = list(in_shardings)
Expand All @@ -517,6 +528,13 @@ def __init__(
else:
self.in_shardings = in_shardings

# Prepend None shardings for captured nodes
n_captured = len(self._captured_info)
if n_captured > 0 and isinstance(in_shardings, (tuple, list)):
jit_in_shardings = (None,) * n_captured + tuple(in_shardings)
else:
jit_in_shardings = in_shardings

donate_argnums_set = frozenset(
(donate_argnums,) if isinstance(donate_argnums, int)
else donate_argnums or ()
Expand All @@ -528,14 +546,15 @@ def __init__(
self.jitted_fn = jax.jit(
SimpleJitFn(
fun,
self.in_shardings,
jit_in_shardings if n_captured > 0 else self.in_shardings,
out_shardings,
donate_argnums_set,
donate_argnames_set,
graph,
tuple(update_shardings),
captured_info=self._captured_info,
),
in_shardings=in_shardings,
in_shardings=jit_in_shardings,
out_shardings=(out_shardings, update_shardings),
static_argnums=static_argnums,
static_argnames=static_argnames,
Expand Down Expand Up @@ -565,9 +584,13 @@ def _maybe_from_tree(self, out):

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
args = (*self.partial_args, *args) # type: ignore[assignment]
args, kwargs = self._maybe_to_tree(args, kwargs)
variables = extract.check_no_aliases('jit', args=args, kwargs=kwargs)
out, updates = self.jitted_fn(*args, **kwargs)
all_args = (*self._captured_info.nodes, *args)
all_args, kwargs = self._maybe_to_tree(all_args, kwargs)
variables = extract.check_no_aliases(
'jit', args=all_args, kwargs=kwargs,
capture_ids=self._captured_info.variable_ids,
)
out, updates = self.jitted_fn(*all_args, **kwargs)
extract.apply_updates(variables, updates)
return self._maybe_from_tree(out)

Expand All @@ -578,26 +601,38 @@ def __get__(self, obj, objtype=None):

def eval_shape(self, *args, **kwargs):
args = (*self.partial_args, *args)
args, kwargs = self._maybe_to_tree(args, kwargs)
all_args = (*self._captured_info.nodes, *args)
all_args, kwargs = self._maybe_to_tree(all_args, kwargs)
if not self.graph:
extract.check_no_aliases('jit', args=args, kwargs=kwargs)
out, updates = self.jitted_fn.eval_shape(*args, **kwargs)
extract.check_no_aliases(
'jit', args=all_args, kwargs=kwargs,
capture_ids=self._captured_info.variable_ids,
)
out, updates = self.jitted_fn.eval_shape(*all_args, **kwargs)
return self._maybe_from_tree(out)

def trace(self, *args, **kwargs):
args = (*self.partial_args, *args)
args, kwargs = self._maybe_to_tree(args, kwargs)
all_args = (*self._captured_info.nodes, *args)
all_args, kwargs = self._maybe_to_tree(all_args, kwargs)
if not self.graph:
extract.check_no_aliases('jit', args=args, kwargs=kwargs)
traced = self.jitted_fn.trace(*args, **kwargs)
extract.check_no_aliases(
'jit', args=all_args, kwargs=kwargs,
capture_ids=self._captured_info.variable_ids,
)
traced = self.jitted_fn.trace(*all_args, **kwargs)
return SimpleTraced(traced, self)

def lower(self, *args, **kwargs):
args = (*self.partial_args, *args)
args, kwargs = self._maybe_to_tree(args, kwargs)
all_args = (*self._captured_info.nodes, *args)
all_args, kwargs = self._maybe_to_tree(all_args, kwargs)
if not self.graph:
extract.check_no_aliases('jit', args=args, kwargs=kwargs)
lowered = self.jitted_fn.lower(*args, **kwargs)
extract.check_no_aliases(
'jit', args=all_args, kwargs=kwargs,
capture_ids=self._captured_info.variable_ids,
)
lowered = self.jitted_fn.lower(*all_args, **kwargs)
return SimpleLowered(lowered, self)
def jit_partial(
fun: tp.Callable[..., R],
Expand Down
Loading
Loading