From 8955e87f15098cf6c0922b10145f7ee2b1df65da Mon Sep 17 00:00:00 2001 From: Flax Team Date: Thu, 19 Mar 2026 03:48:08 -0700 Subject: [PATCH] Upgrade third_party/python_runtime/v3_13_unstable to 3.13.11 Replaced custom PriorityStr with standard tuple sorting in _graph_node_flatten. Python 3.13 optimized str comparison bypasses custom subclasses (causing TypeError), and used getattr in _setattr because unflattening order becomes non-deterministic (preventing AttributeError). PiperOrigin-RevId: 886082903 --- flax/nnx/bridge/module.py | 42 ++++++++------------------------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 5dde1343f..443e368da 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -188,34 +188,6 @@ class AttrPriority(enum.IntEnum): LOW = 100 -class PriorityStr(str): - _priority: AttrPriority - - def __new__(cls, priority: AttrPriority, value: str): - obj = super().__new__(cls, value) - obj._priority = priority - return obj - - def _check_and_get_priority(self, other) -> AttrPriority: - if not isinstance(other, (str, PriorityStr)): - raise NotImplementedError( - f'Cannot compare {type(self)} with {type(other)}' - ) - if isinstance(other, PriorityStr): - return other._priority - return AttrPriority.DEFAULT - - def __lt__(self, other) -> bool: - other_priority = self._check_and_get_priority(other) - if self._priority == other_priority: - return super().__lt__(other) - return self._priority < other_priority - - def __gt__(self, other) -> bool: - other_priority = self._check_and_get_priority(other) - if self._priority == other_priority: - return super().__gt__(other) - return self._priority > other_priority class ModuleBase: if tp.TYPE_CHECKING: @@ -241,7 +213,7 @@ def _getattr(self, name: str) -> tp.Any: return value def _setattr(self, name: str, value: tp.Any) -> None: - if self.scope is not None: + if getattr(self, 'scope', None) is not None: if name in vars(self) and isinstance( state := vars(self)[name], ModuleState ): @@ -254,11 +226,13 @@ def _setattr(self, name: str, value: tp.Any) -> None: def _graph_node_flatten(self): nodes = vars(self).copy() - keys = ( - PriorityStr(self.attr_priorities.get(k, AttrPriority.DEFAULT), k) - for k in nodes.keys() - ) - sorted_nodes = list((k, nodes[k]) for k in sorted(keys)) + def get_priority(k): + if k in ('scope', '_pytree__state', 'attr_priorities'): + return AttrPriority.HIGH + return self.attr_priorities.get(k, AttrPriority.DEFAULT) + + sorted_keys = sorted(nodes.keys(), key=lambda k: (get_priority(k), k)) + sorted_nodes = list((k, nodes[k]) for k in sorted_keys) return sorted_nodes, type(self) def set_attr_priority(self, name: str, value: AttrPriority):