From a147752a945e503b8de21f8adf9be7466fc72f81 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Tue, 19 May 2026 18:32:43 -0700 Subject: [PATCH] Migrate away from private jax._src.tree_util._registry PiperOrigin-RevId: 918133843 --- flax/nnx/graphlib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/nnx/graphlib.py b/flax/nnx/graphlib.py index a3a0aa8fb..fc00c83ba 100644 --- a/flax/nnx/graphlib.py +++ b/flax/nnx/graphlib.py @@ -3627,7 +3627,7 @@ class Static(tp.Generic[A]): class GenericPytree: ... -from jax._src.tree_util import _registry as JAX_PYTREE_REGISTRY + def is_pytree_node( @@ -3637,7 +3637,7 @@ def is_pytree_node( return False elif isinstance(x, Variable): return False - elif type(x) in JAX_PYTREE_REGISTRY: + elif jax.tree_util.is_tree_node(type(x)): return True elif isinstance(x, tuple): return True