Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/attention/benchmark_attention_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def bench_forward(self, warmup, iters, timings_dir):
self.dropout_rng_sharding,
],
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource):
for _ in range(warmup):
customcall_fused_dpa_jit(*customcall_args)

Expand Down Expand Up @@ -227,7 +227,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
),
out_shardings=(None, grad_shardings),
)
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource):
for _ in range(warmup):
jitted_primitive(*customcall_args)

Expand Down
6 changes: 4 additions & 2 deletions tests/jax/test_distributed_dense.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -127,7 +129,7 @@ def test_distributed_gemm(

contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension

with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource):
# TE GEMM result
te_result = _jitted_gemm(
x_sharded,
Expand Down Expand Up @@ -209,7 +211,7 @@ def test_te_distributed_dense_grad(

contracting_dims = ((2,), (0,))

with mesh, autocast(enabled=False, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource):
# Test gradients w.r.t. all inputs
te_grad_func = jax.jit(
jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
Expand Down
6 changes: 4 additions & 2 deletions tests/jax/test_distributed_layernorm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -135,7 +137,7 @@ def ref_func(x, gamma, beta):
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_named_sharding = NamedSharding(mesh, x_pspec)
g_named_sharding = NamedSharding(mesh, g_pspec)
b_named_sharding = NamedSharding(mesh, b_pspec)
Expand Down Expand Up @@ -217,7 +219,7 @@ def ref_func(x, gamma):
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_named_sharding = NamedSharding(mesh, x_pspec)
g_named_sharding = NamedSharding(mesh, g_pspec)
x_ = jax.device_put(x, x_named_sharding)
Expand Down
4 changes: 2 additions & 2 deletions tests/jax/test_distributed_layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def _test_layernorm_mlp_grad(
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(
with jax.set_mesh(mesh), autocast(
enabled=quantization_recipe is not None,
recipe=quantization_recipe,
mesh_resource=mesh_resource,
Expand Down Expand Up @@ -452,7 +452,7 @@ def _test_layernorm_mlp(
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(
with jax.set_mesh(mesh), autocast(
enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
Expand Down
4 changes: 3 additions & 1 deletion tests/jax/test_distributed_softmax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -109,7 +111,7 @@ def impl_test_softmax(
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(mesh_resource=mesh_resource):
with jax.set_mesh(mesh), autocast(mesh_resource=mesh_resource):
x_named_sharding = NamedSharding(mesh, x_pspec)
mask_named_sharding = NamedSharding(mesh, mask_pspec)
x_ = jax.device_put(x, x_named_sharding)
Expand Down
8 changes: 4 additions & 4 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def test_forward(self):
],
)

with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)

Expand All @@ -924,7 +924,7 @@ def test_forward(self):
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

if self.coll_count_ref is not None:
with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
)
)

with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
primitive_out, primitive_dgrad = jitted_primitive(*customcall_args)

reference_out, reference_dgrad = jitted_reference(*args)
Expand Down Expand Up @@ -1126,7 +1126,7 @@ def check_dqkv(primitive, reference, pad, idx):
)

if self.coll_count_ref is not None:
with self.mesh, autocast(mesh_resource=self.mesh_resource):
with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def name_of_wrapper_p():


for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA")
ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA")


def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
Expand Down
62 changes: 55 additions & 7 deletions transformer_engine/jax/sharding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand All @@ -11,6 +13,7 @@
"""
from contextlib import contextmanager
from dataclasses import dataclass
from packaging import version
from typing import Callable, Optional
import warnings

Expand All @@ -20,7 +23,8 @@
from jax.sharding import PartitionSpec, get_abstract_mesh
import numpy as np

_PXLA_THREAD_RESOURCES = pxla.thread_resources
if version.parse(jax.__version__) < version.parse("0.9.0"):
_PXLA_THREAD_RESOURCES = pxla.thread_resources

# Axis Names
BATCH_AXES = "nvte_batch"
Expand All @@ -39,9 +43,11 @@

def _get_mesh():
# Handle Mesh's set via `with mesh:`
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh is not None and not mesh.empty:
return mesh
# ROCm: add JAX version guard for all backends
if version.parse(jax.__version__) < version.parse("0.9.0"):
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
if mesh is not None and not mesh.empty:
return mesh
# Handle Mesh's set via `jax.set_mesh(mesh)`
return jax.sharding.get_abstract_mesh()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backward-incompatibility risk on the upstream/CUDA path. This removes the with mesh: discovery branch entirely, so any caller that still uses the (still-supported in JAX) with mesh: pattern — instead of jax.set_mesh() — will now silently see an empty AbstractMesh here, and TE will treat it as "no mesh" rather than raising. That changes behaviour for non-ROCm users on JAX versions where both patterns still work, not just JAX 0.9.

This file is shared with upstream NVIDIA TE; per the fork's review guidance, behavioural changes to CUDA-reachable code need either (a) a runtime guard so CUDA stays byte-identical, or (b) explicit classification as a generic JAX-0.9 compat fix worth upstreaming. The "ROCm remove deprecated…" comment reads as ROCm-specific, but the change applies unconditionally. Consider falling back to pxla.thread_resources.env.physical_mesh when the abstract mesh is empty (keeping both code paths) or flagging this in the PR description as a deliberate JAX-0.9 cutover.


Expand Down Expand Up @@ -164,6 +170,31 @@ def filter_manual_axes(name_or_tuple):
return x

cleaned_pspec = PartitionSpec(*cleaned_axis_names)

# ROCm: JAX 0.9 compat (all backends) — when an AbstractMesh is active,
# jax.lax.with_sharding_constraint requires the input to already carry a
# NamedSharding. This affects both concrete arrays in eager mode and traced
# values inside jax.jit whose abstract sharding is not a NamedSharding (e.g.
# Module.init() traces over a single-device input and JAX propagates the
# SingleDeviceSharding through the Tracer). In both cases the constraint must
# be skipped because JAX raises unconditionally.
# A UserWarning is emitted only for concrete (non-Tracer) arrays so the user
# gets a visible signal in eager mode; the jit-traced skip is unavoidable and
# kept silent to avoid spurious warnings from traced code.
if hasattr(x, "sharding") and not isinstance(x.sharding, jax.sharding.NamedSharding):
if not isinstance(x, jax.core.Tracer):
warnings.warn(
f"with_sharding_constraint: the sharding constraint {cleaned_pspec!r} was not"
f" applied because the input array carries a {type(x.sharding).__name__} rather"
" than a NamedSharding. This typically happens in eager mode when arrays have not"
" yet been placed on a mesh (e.g. during model initialisation). Wrap the call in"
" jax.jit or ensure the array is on a named mesh before applying sharding"
" constraints.",
UserWarning,
stacklevel=2,
)
return x

return jax.lax.with_sharding_constraint(x, cleaned_pspec)


Expand Down Expand Up @@ -359,6 +390,14 @@ def global_shard_guard(resource: MeshResource):
old_resources = _GLOBAL_MESH_RESOURCE
try:
_GLOBAL_MESH_RESOURCE = resource
# ROCm: JAX 0.9 compat (all backends)
# Validate once at context-setup time, where get_abstract_mesh() correctly
# reflects the physical mesh. Calling _validate_mesh_resource_configuration
# from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9
# because get_abstract_mesh() returns an empty AbstractMesh when called
# from inside a custom_partitioning sharded_impl during jit(...).lower().
if resource is not None:
_validate_mesh_resource_configuration(resource)
yield
finally:
_GLOBAL_MESH_RESOURCE = old_resources
Expand All @@ -375,7 +414,13 @@ def global_mesh_resource() -> MeshResource:
" context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'"
)
_validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE)
# ROCm: JAX 0.9 compat (all backends)
# _validate_mesh_resource_configuration is intentionally NOT called here.
# Validation is done once in global_shard_guard() at context-setup time, where
# get_abstract_mesh() correctly reflects the physical mesh. Calling it here
# would break in JAX 0.9 when global_mesh_resource() is invoked from inside a
# custom_partitioning sharded_impl during jit(...).lower(), at which point
# get_abstract_mesh() returns an empty AbstractMesh.
return _GLOBAL_MESH_RESOURCE


Expand Down Expand Up @@ -418,8 +463,11 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes
Returns:
Reduced tensor
"""
all_axes = get_all_mesh_axes()
for axis in all_axes:
# ROCm: JAX 0.9 compat (all backends)
# Use mesh.axis_names from the concrete mesh argument rather than calling
# get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns
# empty in JAX 0.9 when called from inside a custom_partitioning sharded_impl.
for axis in mesh.axis_names:
if axis != global_mesh_resource().pp_resource:
x = lax_paral_op(x, jax.lax.pmax, axis, mesh)
return x
Expand Down
Loading