diff --git a/benchmarks/attention/benchmark_attention_jax.py b/benchmarks/attention/benchmark_attention_jax.py index 54dd28505..b0b60cdb8 100644 --- a/benchmarks/attention/benchmark_attention_jax.py +++ b/benchmarks/attention/benchmark_attention_jax.py @@ -136,7 +136,7 @@ def bench_forward(self, warmup, iters, timings_dir): self.dropout_rng_sharding, ], ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): customcall_fused_dpa_jit(*customcall_args) @@ -227,7 +227,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ), out_shardings=(None, grad_shardings), ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), fp8_autocast(mesh_resource=self.mesh_resource): for _ in range(warmup): jitted_primitive(*customcall_args) diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index b8caf188d..818298ed8 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -127,7 +129,7 @@ def test_distributed_gemm( contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # TE GEMM result te_result = _jitted_gemm( x_sharded, @@ -209,7 +211,7 @@ def test_te_distributed_dense_grad( contracting_dims = ((2,), (0,)) - with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=False, mesh_resource=mesh_resource): # Test gradients w.r.t. 
all inputs te_grad_func = jax.jit( jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 21359cedf..eb4497a0a 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -135,7 +137,7 @@ def ref_func(x, gamma, beta): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) b_named_sharding = NamedSharding(mesh, b_pspec) @@ -217,7 +219,7 @@ def ref_func(x, gamma): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) g_named_sharding = NamedSharding(mesh, g_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 6a2f395b1..4ed9e3cf5 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -261,7 +261,7 @@ def _test_layernorm_mlp_grad( # Multi GPUs devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( 
enabled=quantization_recipe is not None, recipe=quantization_recipe, mesh_resource=mesh_resource, @@ -452,7 +452,7 @@ def _test_layernorm_mlp( device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast( + with jax.set_mesh(mesh), autocast( enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e..ff44f249c 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -109,7 +111,7 @@ def impl_test_softmax( collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, autocast(mesh_resource=mesh_resource): + with jax.set_mesh(mesh), autocast(mesh_resource=mesh_resource): x_named_sharding = NamedSharding(mesh, x_pspec) mask_named_sharding = NamedSharding(mesh, mask_pspec) x_ = jax.device_put(x, x_named_sharding) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 3b3db30bd..6e9cf23cb 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -907,7 +907,7 @@ def test_forward(self): ], ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out = customcall_fused_dpa_jit(*customcall_args) primitive_out = self.cp_inverse_reorder_fn(primitive_out) @@ -924,7 +924,7 @@ def test_forward(self): assert_allclose(primitive_valid, reference_valid, 
dtype=self.dtype) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = ( customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() ) @@ -1038,7 +1038,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ) ) - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) reference_out, reference_dgrad = jitted_reference(*args) @@ -1126,7 +1126,7 @@ def check_dqkv(primitive, reference, pad, idx): ) if self.coll_count_ref is not None: - with self.mesh, autocast(mesh_resource=self.mesh_resource): + with jax.set_mesh(self.mesh), autocast(mesh_resource=self.mesh_resource): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index e65215bec..670f59cfe 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -223,7 +223,7 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): - ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension else "CUDA") + ffi.register_ffi_target(_name, _value, platform="ROCM" if is_hip_extension() else "CUDA") def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c1..a3df47f11 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +13,7 @@ """ from contextlib import contextmanager from dataclasses import dataclass +from packaging import version from typing import Callable, Optional import warnings @@ -20,7 +23,8 @@ from jax.sharding import PartitionSpec, get_abstract_mesh import numpy as np -_PXLA_THREAD_RESOURCES = pxla.thread_resources +if version.parse(jax.__version__) < version.parse("0.9.0"): + _PXLA_THREAD_RESOURCES = pxla.thread_resources # Axis Names BATCH_AXES = "nvte_batch" @@ -39,9 +43,11 @@ def _get_mesh(): # Handle Mesh's set via `with mesh:` - mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh - if mesh is not None and not mesh.empty: - return mesh + # ROCm: add JAX version guard for all backends + if version.parse(jax.__version__) < version.parse("0.9.0"): + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + if mesh is not None and not mesh.empty: + return mesh # Handle Mesh's set via `jax.set_mesh(mesh)` return jax.sharding.get_abstract_mesh() @@ -164,6 +170,31 @@ def filter_manual_axes(name_or_tuple): return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) + + # ROCm: JAX 0.9 compat (all backends) — when an AbstractMesh is active, + # jax.lax.with_sharding_constraint requires the input to already carry a + # NamedSharding. This affects both concrete arrays in eager mode and traced + # values inside jax.jit whose abstract sharding is not a NamedSharding (e.g. + # Module.init() traces over a single-device input and JAX propagates the + # SingleDeviceSharding through the Tracer). In both cases the constraint must + # be skipped because JAX raises unconditionally. + # A UserWarning is emitted only for concrete (non-Tracer) arrays so the user + # gets a visible signal in eager mode; the jit-traced skip is unavoidable and + # kept silent to avoid spurious warnings from traced code. 
+ if hasattr(x, "sharding") and not isinstance(x.sharding, jax.sharding.NamedSharding): + if not isinstance(x, jax.core.Tracer): + warnings.warn( + f"with_sharding_constraint: the sharding constraint {cleaned_pspec!r} was not" + f" applied because the input array carries a {type(x.sharding).__name__} rather" + " than a NamedSharding. This typically happens in eager mode when arrays have not" + " yet been placed on a mesh (e.g. during model initialisation). Wrap the call in" + " jax.jit or ensure the array is on a named mesh before applying sharding" + " constraints.", + UserWarning, + stacklevel=2, + ) + return x + return jax.lax.with_sharding_constraint(x, cleaned_pspec) @@ -359,6 +390,14 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource + # ROCm: JAX 0.9 compat (all backends) + # Validate once at context-setup time, where get_abstract_mesh() correctly + # reflects the physical mesh. Calling _validate_mesh_resource_configuration + # from global_mesh_resource() (i.e. on every access) breaks in JAX 0.9 + # because get_abstract_mesh() returns an empty AbstractMesh when called + # from inside a custom_partitioning sharded_impl during jit(...).lower(). + if resource is not None: + _validate_mesh_resource_configuration(resource) yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -375,7 +414,13 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + # ROCm: JAX 0.9 compat (all backends) + # _validate_mesh_resource_configuration is intentionally NOT called here. + # Validation is done once in global_shard_guard() at context-setup time, where + # get_abstract_mesh() correctly reflects the physical mesh. 
Calling it here + # would break in JAX 0.9 when global_mesh_resource() is invoked from inside a + # custom_partitioning sharded_impl during jit(...).lower(), at which point + # get_abstract_mesh() returns an empty AbstractMesh. return _GLOBAL_MESH_RESOURCE @@ -418,8 +463,11 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes Returns: Reduced tensor """ - all_axes = get_all_mesh_axes() - for axis in all_axes: + # ROCm: JAX 0.9 compat (all backends) + # Use mesh.axis_names from the concrete mesh argument rather than calling + # get_all_mesh_axes() → _get_mesh() → get_abstract_mesh(), which returns an + # empty AbstractMesh in JAX 0.9 when called from inside a custom_partitioning sharded_impl. + for axis in mesh.axis_names: if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x