Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,3 +1137,8 @@ def filter_kwargs(f, **kwargs):
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in named_params}
return filtered_kwargs

def update_captures(captures: list[variablelib.Capture]):
  """Push every captured state back into the Variable it was captured from.

  Args:
    captures: capture records. For each record, ``record.ref.obj`` is the
      object that receives the update and ``record.variable`` is the state
      handed to its ``update_from_state`` method.
  """
  for record in captures:
    target = record.ref.obj
    target.update_from_state(record.variable)
52 changes: 42 additions & 10 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,20 +78,24 @@ def __post_init__(self):

@extract.treemap_copy_args
def __call__(self, *args, **kwargs):
  """Run the wrapped function and return its output plus state side-channels.

  Returns:
    ``(loss, (updates, captures, aux))`` when ``self.has_aux`` is set,
    otherwise ``(out, (updates, captures))``. ``captures`` is the list
    accumulated on the active capture context while ``self.f`` ran.

  Raises:
    ValueError: if no capture context is active (raised by
      ``variablelib.current_capture_context(True)``).
  """
  # The caller is expected to have opened a capture context; mark it with the
  # trace state of this call before running user code.
  ctx = variablelib.current_capture_context(True)
  ctx.set_trace_state()
  current, snapshot = extract.snapshot(labeled(args=args, kwargs=kwargs))
  if self.graph:
    args, kwargs = extract.from_tree2((args, kwargs))
  out = self.f(*args, **kwargs)
  if self.graph:
    out = extract.to_tree2(out)
  # Single alias check that also covers the captured Variables (the older
  # call without `captures` was a stale duplicate).
  extract.check_no_aliases(
    'grad', **current, captures=ctx.captures, out=out, check=['out']
  )
  updates = extract.get_updates(current, snapshot)

  # The captures travel in the auxiliary slot next to `updates`; the caller
  # unpacks exactly this shape.
  if self.has_aux:
    loss, aux = out
    return loss, (updates, ctx.captures, aux)
  else:
    return out, (updates, ctx.captures)


@dataclasses.dataclass(eq=False)
Expand Down Expand Up @@ -159,7 +163,7 @@ def _grad_general(
allow_int=allow_int,
)

def tree_grad_wrapper(*args, **kwargs):
def simple_grad_wrapper(*args, **kwargs):
if graph:
diff_argnums = (argnums,) if isinstance(argnums, int) else argnums
args_prefix = tuple(
Expand All @@ -171,31 +175,33 @@ def tree_grad_wrapper(*args, **kwargs):

variables = extract.check_no_aliases('grad', args=args, kwargs=kwargs)

fn_out = gradded_fn(*args, **kwargs)
with variablelib.capture_context(variables):
fn_out = gradded_fn(*args, **kwargs)

if return_value:
if has_aux:
(loss, (updates, aux)), grads = fn_out
(loss, (updates, captures, aux)), grads = fn_out
if graph: grads, aux = extract.from_tree2((grads, aux))
result = (loss, aux), grads
else:
(loss, updates), grads = fn_out
(loss, (updates, captures)), grads = fn_out
if graph: grads = extract.from_tree2(grads)
result = loss, grads
else:
if has_aux:
grads, (updates, aux) = fn_out
grads, (updates, captures, aux) = fn_out
if graph: grads, aux = extract.from_tree2((grads, aux))
result = grads, aux
else:
grads, updates = fn_out
grads, (updates, captures) = fn_out
if graph: grads = extract.from_tree2(grads)
result = grads

extract.update_captures(captures)
extract.apply_updates(variables, updates)
return result

return tree_grad_wrapper
return simple_grad_wrapper

jax_argnums: int | tuple[int, ...]
if isinstance(argnums, (int, DiffState)):
Expand Down Expand Up @@ -288,6 +294,8 @@ def process_out(pure_out: A, /) -> A:


@tp.overload


def grad(
f: tp.Callable[..., tp.Any],
*,
Expand All @@ -300,6 +308,8 @@ def grad(
graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
@tp.overload


def grad(
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
Expand Down Expand Up @@ -437,6 +447,8 @@ def grad(


@tp.overload


def value_and_grad(
f: tp.Callable[..., tp.Any],
*,
Expand All @@ -449,6 +461,8 @@ def value_and_grad(
graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
@tp.overload


def value_and_grad(
*,
argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
Expand Down Expand Up @@ -581,6 +595,8 @@ def __call__(self, *primals):


@tp.overload


def vjp(
f: tp.Callable[..., tp.Any],
*primals: tp.Any,
Expand All @@ -590,6 +606,8 @@ def vjp(
graph_updates: bool | None = None,
) -> tuple[tp.Any, tp.Callable] | tuple[tp.Any, tp.Callable, tp.Any]: ...
@tp.overload


def vjp(
*,
has_aux: bool = False,
Expand Down Expand Up @@ -747,6 +765,8 @@ def __call__(self, *primals):


@tp.overload


def jvp(
f: tp.Callable[..., tp.Any],
primals: tuple[tp.Any, ...],
Expand All @@ -757,13 +777,17 @@ def jvp(
graph_updates: bool | None = None,
) -> tuple[tp.Any, ...]: ...
@tp.overload


def jvp(
*,
has_aux: bool = False,
graph: bool | None = None,
graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
@tp.overload


def jvp(
f: tp.Callable[..., tp.Any],
*,
Expand Down Expand Up @@ -1338,6 +1362,8 @@ def defvjp(


@tp.overload


def custom_vjp(
fun: tp.Callable[..., A],
*,
Expand All @@ -1346,6 +1372,8 @@ def custom_vjp(
graph_updates: bool | None = None,
) -> CustomVjp[A] | SimpleCustomVjp[A]: ...
@tp.overload


def custom_vjp(
*,
nondiff_argnums: tuple[int | DiffState, ...] = (),
Expand Down Expand Up @@ -1595,6 +1623,8 @@ def __call__(self, *args, **kwargs):
return out, updates

@tp.overload


def remat(
*,
prevent_cse: bool = True,
Expand All @@ -1604,6 +1634,8 @@ def remat(
graph_updates: bool | None = None,
) -> tp.Callable[[F], F]: ...
@tp.overload


def remat(
f: F,
*,
Expand Down
106 changes: 101 additions & 5 deletions flax/nnx/variablelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# pytype: skip-file
from __future__ import annotations

import contextlib
import dataclasses
import functools
from functools import partial
Expand Down Expand Up @@ -62,17 +63,86 @@
# JAX v0.7.0 or older
from jax.experimental import MutableArray as Ref # type: ignore[no-redef]

@dataclasses.dataclass(eq=False, slots=True)
class Reference(tp.Generic[A]):
  """Thin holder around a single object.

  ``eq=False`` means no ``__eq__`` is generated, so instances keep the default
  identity-based comparison and hashing regardless of what ``obj`` is.
  """

  # The wrapped object.
  obj: A

@jtu.register_dataclass
@dataclasses.dataclass(frozen=True, slots=True)
class Capture:
  """Record of a Variable that was captured during a transform.

  The same Variable is stored twice: wrapped in a ``Reference`` under ``ref``
  and directly under ``variable``.
  """

  # Declared with `jax.tree.static()` -- presumably so the reference is kept as
  # static metadata rather than a pytree child; confirm against the JAX docs.
  ref: Reference[Variable] = jax.tree.static()
  # The captured Variable itself.
  variable: Variable

  @classmethod
  def create(cls, value: A):
    """Build a ``Capture`` whose ``ref`` and ``variable`` both point at ``value``."""
    holder = Reference(value)
    return cls(ref=holder, variable=value)

@dataclasses.dataclass(slots=True)
class VariableCaptureContext:
input_variables: dict[int, tuple] # id -> path
trace_state: tracers.TraceState | None
captures: list[Capture]

def set_trace_state(self):
assert self.trace_state is None
self.trace_state = tracers.TraceState()

def is_valid(self):
return self.trace_state is not None and self.trace_state.is_valid()

def capture(self, variable: Variable):
if id(variable) in self.input_variables:
path = self.input_variables[id(variable)]
raise ValueError(
f'Cannot mutate captured Variable {variable!r} because it is an alias of '
f'input Variable at {jtu.keystr(path)}.'
)
self.captures.append(Capture.create(variable))

@dataclasses.dataclass
class VariableContext(threading.local):
  """Thread-local stacks of Variable-related context settings.

  Subclassing ``threading.local`` gives every thread its own attribute values;
  each stack starts out as a new empty list.
  """

  variable_hijax_stack: list[bool] = dataclasses.field(
    default_factory=list
  )
  variable_ref_stack: list[bool] = dataclasses.field(
    default_factory=list
  )
  eager_shard_stack: list[bool] = dataclasses.field(
    default_factory=list
  )
  # Active capture contexts, innermost last.
  capture_stack: list[VariableCaptureContext] = dataclasses.field(
    default_factory=list
  )


# Module-level instance shared by the helpers in this file.
VARIABLE_CONTEXT = VariableContext()


@tp.overload
def current_capture_context(
  check: tp.Literal[True],
) -> VariableCaptureContext: ...


@tp.overload
def current_capture_context(
  check: bool = False,
) -> VariableCaptureContext | None: ...


def current_capture_context(
  check: bool = False,
) -> VariableCaptureContext | None:
  """Return the innermost active capture context.

  Args:
    check: when true, an empty capture stack is an error instead of yielding
      ``None``.

  Returns:
    The last entry of ``VARIABLE_CONTEXT.capture_stack``, or ``None`` when the
    stack is empty and ``check`` is false.

  Raises:
    ValueError: if the stack is empty and ``check`` is true.
  """
  stack = VARIABLE_CONTEXT.capture_stack
  if stack:
    return stack[-1]
  if check:
    raise ValueError('No capture context found.')
  return None


@contextlib.contextmanager
def capture_context(variables: dict[jtu.KeyPath, Variable]):
  """Context manager that activates a new ``VariableCaptureContext``.

  Args:
    variables: mapping from key path to input Variable. It is indexed by
      ``id(variable)`` so the context can recognise the inputs later.

  Yields:
    The new context, which starts with no trace state and no captures. It is
    pushed onto ``VARIABLE_CONTEXT.capture_stack`` for the duration of the
    ``with`` block and popped again on exit, even if the block raises.
  """
  paths_by_id = {}
  for path, variable in variables.items():
    paths_by_id[id(variable)] = path
  ctx = VariableCaptureContext(
    input_variables=paths_by_id,
    trace_state=None,
    captures=[],
  )
  VARIABLE_CONTEXT.capture_stack.append(ctx)
  try:
    yield ctx
  finally:
    VARIABLE_CONTEXT.capture_stack.pop()

class use_eager_sharding(BaseConfigContext):
"""Sets whether Variables should use eager sharding by default or not.

Expand Down Expand Up @@ -147,10 +217,18 @@ def __len__(self) -> int:


@tp.overload




def var_defaults() -> VarDefaults: ...


@tp.overload




def var_defaults(
*, hijax: bool | None = None, ref: bool | None = None
) -> VarDefaultsContext: ...
Expand Down Expand Up @@ -1388,16 +1466,26 @@ def __init__(
@property
def _can_update(self) -> bool:
  """Whether the Variable can be updated in-place in the current trace context."""
  # Delegate to the shared helper so this property and `_check_can_update`
  # always agree; the captured flag is irrelevant here.
  can_update, _ = self._can_update_and_captured()
  return can_update

def _check_can_update(self):
  """Validate that this Variable may be mutated now, recording a capture if needed.

  Raises:
    errors.TraceContextError: if ``_can_update_and_captured`` reports that the
      Variable cannot be updated.
  """
  can_update, got_captured = self._can_update_and_captured()
  if not can_update:
    raise errors.TraceContextError(
      f'Cannot mutate {type(self).__name__} from a different trace level'
    )
  if got_captured:
    # The mutation is only allowed because a capture context is active, so
    # register this Variable with it.
    current_capture_context(True).capture(self)

def _can_update_and_captured(self) -> tuple[bool, bool]:
  """Return ``(can_update, captured)`` for the current trace context.

  ``(True, False)`` when ``self.hijax`` is set or the Variable's own trace
  state is valid; ``(True, True)`` when neither holds but an active capture
  context is valid; ``(False, False)`` otherwise.
  """
  # `or` short-circuits, so the trace state is only consulted for non-hijax
  # Variables.
  if self.hijax or self._trace_state.is_valid():
    return True, False
  capture_ctx = current_capture_context()
  if capture_ctx and capture_ctx.is_valid():
    return True, True
  return False, False

def __getattr__(self, name: str) -> tp.Any:
if name in object.__getattribute__(self, '_var_metadata'):
Expand Down Expand Up @@ -2381,6 +2469,10 @@ def variable_name_from_type(


@tp.overload




def register_variable_name(
name: str,
typ: type[Variable[tp.Any]],
Expand All @@ -2390,6 +2482,10 @@ def register_variable_name(


@tp.overload




def register_variable_name(
name: str,
*,
Expand Down
Loading
Loading