diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index bbabaa834..4ca8052f8 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -136,6 +136,7 @@ from .nn.normalization import BatchNorm as BatchNorm from .nn.normalization import LayerNorm as LayerNorm from .nn.normalization import RMSNorm as RMSNorm +from .nn.normalization import DyT as DyT from .nn.normalization import GroupNorm as GroupNorm from .nn.normalization import InstanceNorm as InstanceNorm from .nn.normalization import SpectralNorm as SpectralNorm diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index e06ab9fcd..cdc61c85d 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -719,6 +719,165 @@ def __call__(self, x, mask: tp.Optional[jax.Array] = None): ) +class DyT(Module): + """Dynamic Tanh (DyT) normalization (https://arxiv.org/abs/2503.10622). + + DyT is a normalization-free alternative to LayerNorm/RMSNorm that replaces + statistical normalization (mean/variance reduction) with a learned pointwise + bounding function: + + .. math:: + + y = \\gamma \\cdot \\tanh(\\alpha \\cdot x) + \\beta + + where ``alpha`` is a learnable scalar controlling input scaling, ``gamma`` is + a per-feature scale, and ``beta`` is a per-feature bias. Unlike LayerNorm + and RMSNorm, DyT requires **no reduction operations** across the feature + dimension, making it particularly efficient for tensor-parallel training + where reductions require cross-device communication. + + Example usage:: + + >>> from flax import nnx + >>> import jax + + >>> x = jax.random.normal(jax.random.key(0), (5, 6)) + >>> layer = nnx.DyT(num_features=6, rngs=nnx.Rngs(0)) + + >>> nnx.state(layer, nnx.Param) + State({ + 'alpha': Param( + value=Array(0.5, dtype=float32) + ), + 'bias': Param( + value=Array([0., 0., 0., 0., 0., 0.], dtype=float32) + ), + 'scale': Param( + value=Array([1., 1., 1., 1., 1., 1.], dtype=float32) + ) + }) + + >>> y = layer(x) + + Reference: *Transformers without normalization* (Zhu et al., 2025). + + Args: + num_features: the number of input features. + alpha_init: initial value for the learnable scalar ``alpha`` that controls + the input scaling before tanh (default: 0.5). + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: If True, bias (beta) is added. + use_scale: If True, multiply by scale (gamma). When the next layer is + linear (also e.g. nn.relu), this can be disabled since the scaling + will be done by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + feature_axes: Feature axes for learned scale and bias parameters. All + specified axes will have independent learned parameters. + promote_dtype: function to promote the dtype of all input array arguments + (including Variables accessed through ``self``) to the desired dtype. The + function should accept a tuple of ``(inputs, scale, bias)`` and a + ``dtype`` keyword argument, and return a tuple of arrays with the + promoted dtype. + rngs: rng key. + bias_metadata: Optional metadata dictionary to set when initializing + the bias. + scale_metadata: Optional metadata dictionary to set when initializing + the scale. + alpha_metadata: Optional metadata dictionary to set when initializing + the alpha. + """ + + def __init__( + self, + num_features: int, + *, + alpha_init: float = 0.5, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: Initializer = initializers.zeros, + scale_init: Initializer = initializers.ones, + feature_axes: Axes = -1, + promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, + rngs: rnglib.Rngs, + bias_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + scale_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + alpha_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + ): + feature_shape = (num_features,) + + self.alpha = nnx.Param( + jnp.array(alpha_init, dtype=param_dtype), **alpha_metadata + ) + + self.scale: nnx.Param[jax.Array] | None + if use_scale: + key = rngs.params() + self.scale = nnx.Param( + scale_init(key, feature_shape, param_dtype), **scale_metadata + ) + else: + self.scale = nnx.data(None) + + self.bias: nnx.Param[jax.Array] | None + if use_bias: + key = rngs.params() + self.bias = nnx.Param( + bias_init(key, feature_shape, param_dtype), **bias_metadata + ) + else: + self.bias = nnx.data(None) + + self.num_features = num_features + self.dtype = dtype + self.param_dtype = param_dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.feature_axes = feature_axes + self.promote_dtype = promote_dtype + + def __call__(self, x): + """Applies Dynamic Tanh normalization on the input. + + Args: + x: the inputs. + + Returns: + Normalized inputs (the same shape as inputs). + """ + feature_axes = _canonicalize_axes(x.ndim, self.feature_axes) + feature_shape = [1] * x.ndim + for ax in feature_axes: + feature_shape[ax] = x.shape[ax] + + # Promote dtypes for input and all Variables + alpha = self.alpha[...] + scale = self.scale[...] if self.scale else None + bias = self.bias[...] if self.bias else None + x, alpha, scale, bias = self.promote_dtype( + (x, alpha, scale, bias), dtype=self.dtype + ) + + # Core DyT operation: y = gamma * tanh(alpha * x) + beta + y = jnp.tanh(alpha * x) + + args = [x] + if scale is not None: + scale = scale.reshape(feature_shape) + y = y * scale + args.append(scale) + if bias is not None: + bias = bias.reshape(feature_shape) + y = y + bias + args.append(bias) + + dtype = dtypes.canonicalize_dtype(*args, dtype=self.dtype) + return jnp.asarray(y, dtype) + + class GroupNorm(Module): """Group normalization (arxiv.org/abs/1803.08494). diff --git a/tests/nnx/nn/normalization_test.py b/tests/nnx/nn/normalization_test.py index 7766252da..81aa1b744 100644 --- a/tests/nnx/nn/normalization_test.py +++ b/tests/nnx/nn/normalization_test.py @@ -590,5 +590,161 @@ def __call__(self, x): ) +class TestDyT(parameterized.TestCase): + """Tests for DyT (Dynamic Tanh) normalization module.""" + + def test_dyt_default_initialization(self): + """Test that DyT initializes with correct default parameter values.""" + layer = nnx.DyT(num_features=6, rngs=nnx.Rngs(0)) + state = nnx.state(layer, nnx.Param) + + # alpha should be a scalar initialized to 0.5 + np.testing.assert_allclose(state['alpha'].value, 0.5) + # scale (gamma) should be ones + np.testing.assert_array_equal( + state['scale'].value, np.ones(6, dtype=np.float32) + ) + # bias (beta) should be zeros + np.testing.assert_array_equal( + state['bias'].value, np.zeros(6, dtype=np.float32) + ) + + def test_dyt_forward_pass_manual(self): + """Test DyT forward pass matches manual computation.""" + layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0)) + x = jnp.array([[1.0, -2.0, 3.0, -4.0]]) + + y = layer(x) + # With default params: y = 1.0 * tanh(0.5 * x) + 0.0 + expected = jnp.tanh(0.5 * x) + np.testing.assert_allclose(y, expected, atol=1e-6) + + def test_dyt_output_shape(self): + """Test that DyT preserves input shape.""" + layer = nnx.DyT(num_features=6, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(0), (5, 6)) + y = layer(x) + self.assertEqual(x.shape, y.shape) + + def test_dyt_output_bounded(self): + """Test that DyT output is bounded by scale (gamma) values.""" + layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0)) + # Large inputs should be bounded by tanh saturation + x = jnp.array([[100.0, -100.0, 1000.0, -1000.0]]) + y = layer(x) + # With default scale=1 and bias=0, output should be in [-1, 1] + self.assertTrue(jnp.all(jnp.abs(y) <= 1.0 + 1e-6)) + + def test_dyt_gradient_flow(self): + """Test that gradients flow through all learnable parameters.""" + layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(0), (3, 4)) + + def loss_fn(model): + return jnp.sum(model(x) ** 2) + + grads = nnx.grad(loss_fn)(layer) + grad_state = nnx.state(grads, nnx.Param) + + # All parameters should have non-zero gradients + self.assertFalse(jnp.allclose(grad_state['alpha'].value, 0.0)) + self.assertFalse(jnp.allclose(grad_state['scale'].value, 0.0)) + self.assertFalse(jnp.allclose(grad_state['bias'].value, 0.0)) + + @parameterized.product( + dtype=[jnp.float32, jnp.float16], + param_dtype=[jnp.float32, jnp.float16], + ) + def test_dyt_dtype_handling( + self, + dtype: Dtype, + param_dtype: Dtype, + ): + """Test DyT with different input and parameter dtypes.""" + layer = nnx.DyT( + num_features=6, + dtype=dtype, + param_dtype=param_dtype, + rngs=nnx.Rngs(0), + ) + x = jax.random.normal(jax.random.key(0), (5, 6)) + y = layer(x) + self.assertEqual(y.dtype, dtype) + + def test_dyt_no_scale(self): + """Test DyT with use_scale=False.""" + layer = nnx.DyT( + num_features=4, use_scale=False, rngs=nnx.Rngs(0) + ) + x = jnp.array([[1.0, -2.0, 3.0, -4.0]]) + y = layer(x) + # Without scale: y = tanh(0.5 * x) + 0.0 + expected = jnp.tanh(0.5 * x) + np.testing.assert_allclose(y, expected, atol=1e-6) + + def test_dyt_no_bias(self): + """Test DyT with use_bias=False.""" + layer = nnx.DyT( + num_features=4, use_bias=False, rngs=nnx.Rngs(0) + ) + x = jnp.array([[1.0, -2.0, 3.0, -4.0]]) + y = layer(x) + # Without bias: y = 1.0 * tanh(0.5 * x) + expected = jnp.tanh(0.5 * x) + np.testing.assert_allclose(y, expected, atol=1e-6) + + def test_dyt_no_scale_no_bias(self): + """Test DyT with both use_scale=False and use_bias=False.""" + layer = nnx.DyT( + num_features=4, + use_scale=False, + use_bias=False, + rngs=nnx.Rngs(0), + ) + x = jnp.array([[1.0, -2.0, 3.0, -4.0]]) + y = layer(x) + expected = jnp.tanh(0.5 * x) + np.testing.assert_allclose(y, expected, atol=1e-6) + + def test_dyt_custom_alpha(self): + """Test DyT with a custom alpha initialization.""" + layer = nnx.DyT( + num_features=4, alpha_init=1.0, rngs=nnx.Rngs(0) + ) + x = jnp.array([[1.0, -2.0, 3.0, -4.0]]) + y = layer(x) + expected = jnp.tanh(1.0 * x) + np.testing.assert_allclose(y, expected, atol=1e-6) + + def test_dyt_multidim_input(self): + """Test DyT with higher-dimensional inputs (e.g., image-like).""" + # (batch, height, width, channels) + layer = nnx.DyT(num_features=8, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(0), (2, 4, 4, 8)) + y = layer(x) + self.assertEqual(y.shape, (2, 4, 4, 8)) + + def test_dyt_jit_compatible(self): + """Test that DyT works correctly under jax.jit.""" + layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(0), (3, 4)) + + @nnx.jit + def forward(model, x): + return model(x) + + y_jit = forward(layer, x) + y_eager = layer(x) + np.testing.assert_allclose(y_jit, y_eager, atol=1e-6) + + def test_dyt_zero_input(self): + """Test DyT with zero input produces zero output (with default params).""" + layer = nnx.DyT(num_features=4, rngs=nnx.Rngs(0)) + x = jnp.zeros((2, 4)) + y = layer(x) + # tanh(0) = 0, so with scale=1 and bias=0, output should be 0 + np.testing.assert_allclose(y, jnp.zeros_like(x), atol=1e-7) + + if __name__ == '__main__': absltest.main()