Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_enum(
'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in'
'mode', 'nnx', ['all', 'nnx', 'jax', 'jit_partial'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
Expand Down Expand Up @@ -91,7 +91,6 @@ def main(argv):
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
t0 = time()

@nnx.jit
def step_nnx(model: MLP, optimizer: nnx.Optimizer):
Expand All @@ -108,6 +107,35 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
print('total time:', total_time)
print(f'time per step: {time_per_step * 1e6:.2f} µs')
print(f'time per layer: {time_per_layer * 1e6:.2f} µs')
print()

# ------------------------------------------------------------
# JIT Partial
# ------------------------------------------------------------
if mode in ('all', 'jit_partial'):
  # Build a fresh model and optimizer for this mode.
  model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
  tx = optax.sgd(1e-3)
  optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

  def step_partial(model: MLP, optimizer: nnx.Optimizer):
    """Do nothing: the body is empty, so the loop below times only the call."""

  # model and optimizer are handed to jit_partial once here; the callable it
  # returns is invoked with no arguments inside the timed loop.
  step_partial_jit = nnx.jit_partial(
    step_partial, model, optimizer, graph=False
  )

  # Only the calls themselves are timed; setup above is excluded.
  t0 = time()
  for _ in range(total_steps):
    step_partial_jit()
  total_time = time() - t0

  time_per_step = total_time / total_steps
  time_per_layer = time_per_step / depth

  print('### JIT PARTIAL ###')
  print('total time:', total_time)
  print(f'time per step: {time_per_step * 1e6:.2f} µs')
  print(f'time per layer: {time_per_layer * 1e6:.2f} µs')
  print()

# ------------------------------------------------------------
# JAX
Expand All @@ -117,7 +145,6 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
t0 = time()

@jax.jit
def step_jax(graphdef, state):
Expand Down
84 changes: 78 additions & 6 deletions benchmarks/nnx_simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

# %%
import cProfile
import pstats
import io
from functools import partial
import jax
import jax.numpy as jnp
Expand All @@ -27,12 +30,13 @@

FLAGS = flags.FLAGS
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
'mode', 'all', ['all', 'nnx', 'jax', 'jit_partial'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 5, 'Depth of the model')
flags.DEFINE_bool('profile', False, 'Enable cProfile profiling')


def dataset(X, Y, batch_size):
Expand Down Expand Up @@ -67,13 +71,13 @@ class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
self.intermediates = nnx.List([
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
])
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
self.count[...] += 1
x = nnx.relu(self.linear_in(x))
for layer in self.intermediates:
x = nnx.relu(layer(x))
Expand Down Expand Up @@ -118,6 +122,7 @@ def test_step_nnx(model: MLP, batch):
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}

logs = {'loss': jnp.array(0.0)}
for step, batch in enumerate(dataset(X, Y, batch_size)):
train_step_nnx(model, optimizer, batch)

Expand All @@ -132,7 +137,73 @@ def test_step_nnx(model: MLP, batch):
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)
print('times called:', model.count[...])
print()

if mode in ('jit_partial', 'all'):
  # Benchmark the training loop with the step functions wrapped by
  # nnx.jit_partial: model/optimizer are passed once when wrapping, and only
  # the batch is passed on each call.
  model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
  tx = optax.sgd(1e-3)
  optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
  # NOTE(review): the clock starts before wrapping and warmup, so the reported
  # time also covers the warmup steps below (and presumably compilation) --
  # confirm this is intended for comparison with the other modes.
  t0 = time()

  def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
    """Apply one optimizer update for `batch` (an `(x, y)` pair), MSE loss."""
    x, y = batch

    def loss_fn(model: MLP):
      # Mean squared error between targets and the model's predictions.
      y_pred = model(x)
      return jnp.mean((y - y_pred) ** 2)

    grads = nnx.grad(loss_fn)(model)
    optimizer.update(model, grads)

  def test_step(model: MLP, batch):
    """Evaluate the MSE loss on `batch` and return it as `{'loss': loss}`."""
    x, y = batch
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return {'loss': loss}

  train_step_fn = nnx.jit_partial(
    train_step, model, optimizer, graph=False
  )
  test_step_fn = nnx.jit_partial(test_step, model, graph=False)

  # Default so the final report works even if no test step ever runs.
  logs = {'loss': jnp.array(0.0)}

  # Warmup: run 11 steps (indices 0..10) before the profiler is switched on.
  for step, batch in enumerate(dataset(X, Y, batch_size)):
    train_step_fn(batch)
    if step >= 10:
      break

  pr = None
  if FLAGS.profile:
    pr = cProfile.Profile()
    pr.enable()

  for step, batch in enumerate(dataset(X, Y, batch_size)):
    train_step_fn(batch)

    if step % 1000 == 0:
      logs = test_step_fn((X, Y))

    if step >= total_steps - 1:
      break

  if pr is not None:
    pr.disable()

  # Format the loss (which needs its concrete value) and stop the clock BEFORE
  # building the profile report, so that report generation is not counted in
  # the benchmark time. Printed output and its order are unchanged.
  final_loss_line = f'final loss: {logs["loss"]}'
  total_time = time() - t0

  if pr is not None:
    # Emit one report per sort order, top 40 entries each.
    for sort_key in ('cumulative', 'tottime'):
      s = io.StringIO()
      ps = pstats.Stats(pr, stream=s)
      ps.sort_stats(sort_key)
      ps.print_stats(40)
      print(s.getvalue())

  print('### JIT PARTIAL ###')
  print(final_loss_line)
  print('total time:', total_time)
  print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
  print('times called:', model.count[...])
  print()

if mode == 'jax' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
Expand Down Expand Up @@ -165,6 +236,7 @@ def test_step_jax(state, batch):

graphdef, state = nnx.split((model, optimizer))

logs = {'loss': jnp.array(0.0)}
for step, batch in enumerate(dataset(X, Y, batch_size)):
state = train_step_jax(state, batch)

Expand All @@ -181,7 +253,7 @@ def test_step_jax(state, batch):
total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)
print('times called:', model.count[...])


if __name__ == '__main__':
Expand Down
72 changes: 27 additions & 45 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def check_prefix(
graph_updates: bool,
none_leaf: bool = True,
):
unique_prefixes: OrderedDict[tp.Any, tp.Any] = OrderedDict()
unique_prefixes: set[tp.Any] = set()

def _check_prefix(path, leaf):
if isinstance(leaf, variablelib.Variable):
Expand Down Expand Up @@ -798,12 +798,12 @@ def _check_prefix(path, leaf):
)

def _collect_prefix(_, leaf):
unique_prefixes[leaf] = leaf
unique_prefixes.add(leaf)

jax.tree.map_with_path(
_collect_prefix, prefix, is_leaf=lambda x: x is None and none_leaf
)
return unique_prefixes
return list(unique_prefixes)


def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> bool:
Expand All @@ -818,27 +818,28 @@ def variable_changed(post: variablelib.Variable, pre: variablelib.Variable) -> b
]


@dataclasses.dataclass(slots=True)
class Updates(
tp.Sequence[tuple[jax.tree_util.KeyPath, variablelib.Variable]],
reprlib.Representable,
):
__slots__ = ('_keys', '_values')
_keys: list[tp.Any] = dataclasses.field(default_factory=list)
_values: list[variablelib.Variable] = dataclasses.field(default_factory=list)

_keys: list[jax.tree_util.KeyPath]
_values: list[variablelib.Variable]

def __init__(
self,
@classmethod
def create(
cls,
items: tp.Iterable[
tuple[jax.tree_util.KeyPath, variablelib.Variable]
] = (),
):
self._keys, self._values = [], []
) -> 'Updates':
keys, values = [], []
for key, value in items:
self._keys.append(key)
self._values.append(value)
keys.append(key)
values.append(value)
return cls(_keys=keys, _values=values)

def append(self, key: jax.tree_util.KeyPath, value: variablelib.Variable):
def append(self, key: tp.Any, value: variablelib.Variable):
self._keys.append(key)
self._values.append(value)

Expand Down Expand Up @@ -880,7 +881,7 @@ def __len__(self):
return len(self._keys)

def __iter__(self):
return iter(zip(self._keys, self._values))
return zip(self._keys, self._values)

def __nnx_repr__(self):
yield reprlib.Object(type=type(self), kv_sep=': ', start='({', end='})')
Expand All @@ -892,30 +893,10 @@ def __nnx_repr__(self):
)


def _updates_flatten_with_keys(x: Updates):
  """Flatten `x` into `(keyed_children, aux_data)`.

  Each entry of `x._values` is paired with a `FlattenedIndexKey` holding its
  position in the list; `x._keys` is returned unchanged as the second element.
  """
  keyed_children = []
  for position, variable in enumerate(x._values):
    keyed_children.append(
      (jax.tree_util.FlattenedIndexKey(position), variable)
    )
  return keyed_children, x._keys


def _updates_flatten(x: Updates):
  """Flatten `x` into `(children, aux_data)`: its values, then its keys."""
  children = x._values
  aux_data = x._keys
  return children, aux_data


def _updates_unflatten(keys, values) -> Updates:
  """Rebuild an `Updates` from `keys` and `values`.

  The instance is allocated with `object.__new__`, so `Updates.__init__` is
  not run; `values` is copied into a new list, `keys` is stored as given.
  """
  restored = object.__new__(Updates)
  restored._values = list(values)
  restored._keys = keys
  return restored


jax.tree_util.register_pytree_with_keys(
jax.tree_util.register_dataclass(
Updates,
_updates_flatten_with_keys,
_updates_unflatten,
flatten_func=_updates_flatten,
data_fields=['_values'],
meta_fields=['_keys'],
)

def get_updates(
Expand All @@ -929,7 +910,7 @@ def get_updates(
if keep_fn is None:
keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap)

updates = OrderedDict((pfx, Updates()) for pfx in known_prefixes)
updates = {pfx: Updates.create() for pfx in known_prefixes}

def _mask_updates(path, prefix_leaf, current, snapshot):
if isinstance(current, variablelib.Variable):
Expand All @@ -944,14 +925,14 @@ def _mask_updates(path, prefix_leaf, current, snapshot):
_mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf,
prefix_leaf=prefix_leaf,
)
return updates
return list(updates.values())


def apply_updates(
variables: dict[jax.tree_util.KeyPath, variablelib.Variable],
updates: OrderedDict[tp.Any, Updates],
updates: list[Updates],
):
for _, flat_state in updates.items():
for flat_state in updates:
for path, update in flat_state:
if path in variables:
variable = variables[path]
Expand All @@ -965,6 +946,7 @@ def apply_updates(
)



def treemap_copy_args(f: F) -> F:
@functools.wraps(f)
def wrapper(*args, **kwargs):
Expand Down Expand Up @@ -1110,9 +1092,9 @@ def _apply_prefix(jax_path, leaf):

return jax.tree.map_with_path(_apply_prefix, node, is_leaf=is_leaf)

def to_masked(tree, all_updates: OrderedDict[tp.Any, Updates]):
combined: OrderedDict[tp.Any, tp.Any] = OrderedDict()
for updates in all_updates.values():
def to_masked(tree, all_updates: list[Updates]):
combined: dict[tp.Any, tp.Any] = {}
for updates in all_updates:
combined.update(updates)
return jax.tree.map_with_path(
lambda path, _: combined.get(path, None), tree,
Expand Down
Loading
Loading