Out sharding for modules initialized with JIT is incorrect #5127

@qGentry

Hey folks, me again.

I recently ran into the following problem when initializing a model whose parameters live on multiple meshes. The output sharding returned by the jitted init_fn does not stick to the specified shardings: some parameters come back on a different mesh or as a GSPMDSharding. It also seems that an output tensor's mesh depends on the ordering of the flattened tree. Check out this repro script:

import jax
import jax.numpy as jnp
import flax.nnx as nnx


mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))
mesh2 = jax.make_mesh((2, 2, 2), ("x", "y", "z"))
rules2 = (("X", "x"), ("Y", "y"), ("Z", "z"))
mesh3 = jax.make_mesh((8,), ("c",))
rules3 = (("C", "c"),)

mesh_data = jax.make_mesh((4, 2), ("data", "context"))


class Model(nnx.Module):
    def __init__(self):
        self.small_linear1 = nnx.Param(
            jnp.ones((16, 16)), 
            sharding=("A", "B"), 
            mesh=mesh1,
            sharding_rules=rules1,
        )
        self.small_linear2 = nnx.Param(
            jnp.ones((16, 16, 16)), 
            sharding=("X", "Y", "Z"), 
            mesh=mesh2,
            sharding_rules=rules2,
        )
        self.small_linear3 = nnx.Param(
            jnp.ones((16, 16)),
            sharding=("C",), 
            mesh=mesh3,
            sharding_rules=rules3,
        )


def init_model_no_jit():
    return Model()


@nnx.jit
def init_model_nnx_jit():
    model = init_model_no_jit()
    return model


with mesh_data:
    model_nnx_jit = init_model_nnx_jit()
    model_no_jit = init_model_no_jit()

    def _print_t_sharding(key, t):
        print(f"Key: {'.'.join(map(str, key))}, shape: {t.shape}, sharding: {t.sharding}")

    print("\nSharding without JIT:")
    jax.tree.map_with_path(_print_t_sharding, model_no_jit)

    print("Sharding with NNX.JIT:")
    jax.tree.map_with_path(_print_t_sharding, model_nnx_jit)

Output:

Sharding without JIT:
Key: .small_linear1..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)
Key: .small_linear2..value, shape: (16, 16, 16), sharding: NamedSharding(mesh=Mesh('x': 2, 'y': 2, 'z': 2, axis_types=(Auto, Auto, Auto)), spec=PartitionSpec('x', 'y', 'z'), memory_kind=device)
Key: .small_linear3..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('c': 8, axis_types=(Auto,)), spec=PartitionSpec('c',), memory_kind=device)
Sharding with NNX.JIT:
Key: .small_linear1..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device)
Key: .small_linear2..value, shape: (16, 16, 16), sharding: GSPMDSharding({devices=[2,2,2]<=[8]}, memory_kind=device)
Key: .small_linear3..value, shape: (16, 16), sharding: NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(('a', 'b'),), memory_kind=device)

The non-JIT version, on the other hand, works correctly, but as one may imagine it is not suitable for large-scale init.
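To make the comparison between the two runs less manual, here is a minimal sketch of a checker that walks two pytrees and reports leaves whose shardings differ. `sharding_mismatches` is a hypothetical helper (not part of JAX or Flax), and it assumes both trees flatten to the same leaves in the same order. It also reports whether the two shardings are layout-equivalent according to `Sharding.is_equivalent_to`, since differing sharding objects may still describe the same placement; the issue itself does not establish whether that is the case for the output above. The demo at the bottom runs on a single device, so its meshes are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharding_mismatches(expected_tree, actual_tree):
    """List leaves whose sharding objects differ between two pytrees of arrays.

    Each entry is (path, expected_sharding, actual_sharding, layout_equivalent).
    Hypothetical helper; assumes both trees flatten to the same leaf order.
    """
    expected = jax.tree_util.tree_leaves_with_path(expected_tree)
    actual = jax.tree_util.tree_leaves_with_path(actual_tree)
    if len(expected) != len(actual):
        raise ValueError("trees have a different number of leaves")

    mismatches = []
    for (path, exp), (_, act) in zip(expected, actual):
        if exp.sharding == act.sharding:
            continue
        # Different sharding objects can still put the same shards on the same devices.
        equivalent = exp.sharding.is_equivalent_to(act.sharding, exp.ndim)
        mismatches.append(
            (jax.tree_util.keystr(path), exp.sharding, act.sharding, equivalent)
        )
    return mismatches


# Single-device demo: two meshes over the same device with different axis names.
devices = np.array(jax.devices()[:1])
mesh_a = Mesh(devices, ("a",))
mesh_b = Mesh(devices, ("b",))
tree_a = {"w": jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh_a, P("a")))}
tree_b = {"w": jax.device_put(jnp.ones((4, 4)), NamedSharding(mesh_b, P("b")))}

for path, exp, act, equivalent in sharding_mismatches(tree_a, tree_b):
    print(f"{path}: expected {exp}, got {act}, layout-equivalent: {equivalent}")
```

With the repro script, the call would be `sharding_mismatches(model_no_jit, model_nnx_jit)`, relying on the same pytree flattening that the `jax.tree.map_with_path` calls already use.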
