Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
from .nn.normalization import BatchNorm as BatchNorm
from .nn.normalization import LayerNorm as LayerNorm
from .nn.normalization import RMSNorm as RMSNorm
from .nn.normalization import DyT as DyT
from .nn.normalization import GroupNorm as GroupNorm
from .nn.normalization import InstanceNorm as InstanceNorm
from .nn.normalization import SpectralNorm as SpectralNorm
Expand Down
159 changes: 159 additions & 0 deletions flax/nnx/nn/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,165 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None):
)


class DyT(Module):
"""Dynamic Tanh (DyT) normalization (https://arxiv.org/abs/2503.10622).

DyT is a normalization-free alternative to LayerNorm/RMSNorm that replaces
statistical normalization (mean/variance reduction) with a learned pointwise
bounding function:

.. math::

y = \\gamma \\cdot \\tanh(\\alpha \\cdot x) + \\beta

where ``alpha`` is a learnable scalar controlling input scaling, ``gamma`` is
a per-feature scale, and ``beta`` is a per-feature bias. Unlike LayerNorm
and RMSNorm, DyT requires **no reduction operations** across the feature
dimension, making it particularly efficient for tensor-parallel training
where reductions require cross-device communication.

Example usage::

>>> from flax import nnx
>>> import jax

>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.DyT(num_features=6, rngs=nnx.Rngs(0))

>>> nnx.state(layer, nnx.Param)
State({
'alpha': Param(
value=Array(0.5, dtype=float32)
),
'bias': Param(
value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
),
'scale': Param(
value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
)
})

>>> y = layer(x)

Reference: *Transformers without normalization* (Zhu et al., 2025).

Args:
num_features: the number of input features.
alpha_init: initial value for the learnable scalar ``alpha`` that controls
the input scaling before tanh (default: 0.5).
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: If True, bias (beta) is added.
use_scale: If True, multiply by scale (gamma). When the next layer is
linear (also e.g. nn.relu), this can be disabled since the scaling
will be done by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
feature_axes: Feature axes for learned scale and bias parameters. All
specified axes will have independent learned parameters.
promote_dtype: function to promote the dtype of all input array arguments
(including Variables accessed through ``self``) to the desired dtype. The
function should accept a tuple of ``(inputs, scale, bias)`` and a
``dtype`` keyword argument, and return a tuple of arrays with the
promoted dtype.
rngs: rng key.
bias_metadata: Optional metadata dictionary to set when initializing
the bias.
scale_metadata: Optional metadata dictionary to set when initializing
the scale.
alpha_metadata: Optional metadata dictionary to set when initializing
the alpha.
"""

def __init__(
self,
num_features: int,
*,
alpha_init: float = 0.5,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
use_bias: bool = True,
use_scale: bool = True,
bias_init: Initializer = initializers.zeros,
scale_init: Initializer = initializers.ones,
feature_axes: Axes = -1,
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
rngs: rnglib.Rngs,
bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}),
scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}),
alpha_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}),
):
feature_shape = (num_features,)

self.alpha = nnx.Param(
jnp.array(alpha_init, dtype=param_dtype), **alpha_metadata
)

self.scale: nnx.Param[jax.Array] | None
if use_scale:
key = rngs.params()
self.scale = nnx.Param(
scale_init(key, feature_shape, param_dtype), **scale_metadata
)
else:
self.scale = nnx.data(None)

self.bias: nnx.Param[jax.Array] | None
if use_bias:
key = rngs.params()
self.bias = nnx.Param(
bias_init(key, feature_shape, param_dtype), **bias_metadata
)
else:
self.bias = nnx.data(None)

self.num_features = num_features
self.dtype = dtype
self.param_dtype = param_dtype
self.use_bias = use_bias
self.use_scale = use_scale
self.feature_axes = feature_axes
self.promote_dtype = promote_dtype

def __call__(self, x):
"""Applies Dynamic Tanh normalization on the input.

Args:
x: the inputs.

Returns:
Normalized inputs (the same shape as inputs).
"""
feature_axes = _canonicalize_axes(x.ndim, self.feature_axes)
feature_shape = [1] * x.ndim
for ax in feature_axes:
feature_shape[ax] = x.shape[ax]

# Promote dtypes for input and all Variables
alpha = self.alpha[...]
scale = self.scale[...] if self.scale else None
bias = self.bias[...] if self.bias else None
x, alpha, scale, bias = self.promote_dtype(
(x, alpha, scale, bias), dtype=self.dtype
)

# Core DyT operation: y = gamma * tanh(alpha * x) + beta
y = jnp.tanh(alpha * x)

args = [x]
if scale is not None:
scale = scale.reshape(feature_shape)
y = y * scale
args.append(scale)
if bias is not None:
bias = bias.reshape(feature_shape)
y = y + bias
args.append(bias)

dtype = dtypes.canonicalize_dtype(*args, dtype=self.dtype)
return jnp.asarray(y, dtype)


class GroupNorm(Module):
"""Group normalization (arxiv.org/abs/1803.08494).

Expand Down
156 changes: 156 additions & 0 deletions tests/nnx/nn/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,161 @@ def __call__(self, x):
)


class TestDyT(parameterized.TestCase):
"""Tests for DyT (Dynamic Tanh) normalization module."""

def test_dyt_default_initialization(self):
"""Test that DyT initializes with correct default parameter values."""
layer = nnx.DyT(num_features=6, rngs=nnx.Rngs(0))
state = nnx.state(layer, nnx.Param)

# alpha should be a scalar initialized to 0.5
np.testing.assert_allclose(state['alpha'].value, 0.5)
# scale (gamma) should be ones
np.testing.assert_array_equal(
state['scale'].value, np.ones(6, dtype=np.float32)
)
# bias (beta) should be zeros
np.testing.assert_array_equal(
state['bias'].value, np.zeros(6, dtype=np.float32)
)

def test_dyt_forward_pass_manual(self):
"""Test DyT forward pass matches manual computation."""
layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0))
x = jnp.array([[1.0, -2.0, 3.0, -4.0]])

y = layer(x)
# With default params: y = 1.0 * tanh(0.5 * x) + 0.0
expected = jnp.tanh(0.5 * x)
np.testing.assert_allclose(y, expected, atol=1e-6)

def test_dyt_output_shape(self):
"""Test that DyT preserves input shape."""
layer = nnx.DyT(num_features=6, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (5, 6))
y = layer(x)
self.assertEqual(x.shape, y.shape)

def test_dyt_output_bounded(self):
"""Test that DyT output is bounded by scale (gamma) values."""
layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0))
# Large inputs should be bounded by tanh saturation
x = jnp.array([[100.0, -100.0, 1000.0, -1000.0]])
y = layer(x)
# With default scale=1 and bias=0, output should be in [-1, 1]
self.assertTrue(jnp.all(jnp.abs(y) <= 1.0 + 1e-6))

def test_dyt_gradient_flow(self):
"""Test that gradients flow through all learnable parameters."""
layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (3, 4))

def loss_fn(model):
return jnp.sum(model(x) ** 2)

grads = nnx.grad(loss_fn)(layer)
grad_state = nnx.state(grads, nnx.Param)

# All parameters should have non-zero gradients
self.assertFalse(jnp.allclose(grad_state['alpha'].value, 0.0))
self.assertFalse(jnp.allclose(grad_state['scale'].value, 0.0))
self.assertFalse(jnp.allclose(grad_state['bias'].value, 0.0))

@parameterized.product(
dtype=[jnp.float32, jnp.float16],
param_dtype=[jnp.float32, jnp.float16],
)
def test_dyt_dtype_handling(
self,
dtype: Dtype,
param_dtype: Dtype,
):
"""Test DyT with different input and parameter dtypes."""
layer = nnx.DyT(
num_features=6,
dtype=dtype,
param_dtype=param_dtype,
rngs=nnx.Rngs(0),
)
x = jax.random.normal(jax.random.key(0), (5, 6))
y = layer(x)
self.assertEqual(y.dtype, dtype)

def test_dyt_no_scale(self):
"""Test DyT with use_scale=False."""
layer = nnx.DyT(
num_features=4, use_scale=False, rngs=nnx.Rngs(0)
)
x = jnp.array([[1.0, -2.0, 3.0, -4.0]])
y = layer(x)
# Without scale: y = tanh(0.5 * x) + 0.0
expected = jnp.tanh(0.5 * x)
np.testing.assert_allclose(y, expected, atol=1e-6)

def test_dyt_no_bias(self):
"""Test DyT with use_bias=False."""
layer = nnx.DyT(
num_features=4, use_bias=False, rngs=nnx.Rngs(0)
)
x = jnp.array([[1.0, -2.0, 3.0, -4.0]])
y = layer(x)
# Without bias: y = 1.0 * tanh(0.5 * x)
expected = jnp.tanh(0.5 * x)
np.testing.assert_allclose(y, expected, atol=1e-6)

def test_dyt_no_scale_no_bias(self):
"""Test DyT with both use_scale=False and use_bias=False."""
layer = nnx.DyT(
num_features=4,
use_scale=False,
use_bias=False,
rngs=nnx.Rngs(0),
)
x = jnp.array([[1.0, -2.0, 3.0, -4.0]])
y = layer(x)
expected = jnp.tanh(0.5 * x)
np.testing.assert_allclose(y, expected, atol=1e-6)

def test_dyt_custom_alpha(self):
"""Test DyT with a custom alpha initialization."""
layer = nnx.DyT(
num_features=4, alpha_init=1.0, rngs=nnx.Rngs(0)
)
x = jnp.array([[1.0, -2.0, 3.0, -4.0]])
y = layer(x)
expected = jnp.tanh(1.0 * x)
np.testing.assert_allclose(y, expected, atol=1e-6)

def test_dyt_multidim_input(self):
"""Test DyT with higher-dimensional inputs (e.g., image-like)."""
# (batch, height, width, channels)
layer = nnx.DyT(num_features=8, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (2, 4, 4, 8))
y = layer(x)
self.assertEqual(y.shape, (2, 4, 4, 8))

def test_dyt_jit_compatible(self):
"""Test that DyT works correctly under jax.jit."""
layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (3, 4))

@nnx.jit
def forward(model, x):
return model(x)

y_jit = forward(layer, x)
y_eager = layer(x)
np.testing.assert_allclose(y_jit, y_eager, atol=1e-6)

def test_dyt_zero_input(self):
"""Test DyT with zero input produces zero output (with default params)."""
layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0))
x = jnp.zeros((2, 4))
y = layer(x)
# tanh(0) = 0, so with scale=1 and bias=0, output should be 0
np.testing.assert_allclose(y, jnp.zeros_like(x), atol=1e-7)


if __name__ == '__main__':
absltest.main()