From adc09687f45c2cbd732100fbd9436513369b8121 Mon Sep 17 00:00:00 2001 From: Ivan Podkidyshev Date: Wed, 15 Apr 2026 13:34:17 +0200 Subject: [PATCH 1/3] refactor env variables and sbatch directives in slurm --- .../systems/slurm/single_sbatch_runner.py | 6 +- .../slurm/slurm_command_gen_strategy.py | 114 ++++++++++-------- src/cloudai/workloads/common/nixl.py | 1 - .../deepep/slurm_command_gen_strategy.py | 46 +++---- .../nixl_ep/slurm_command_gen_strategy.py | 32 +++-- .../slurm_command_gen_strategy.py | 18 +-- tests/ref_data/nixl-kvbench.sbatch | 1 - tests/ref_data/nixl-perftest.sbatch | 1 - tests/ref_data/nixl_bench.sbatch | 1 - .../slurm/test_command_gen_strategy.py | 10 +- .../test_command_gen_strategy_slurm.py | 7 +- 11 files changed, 116 insertions(+), 121 deletions(-) diff --git a/src/cloudai/systems/slurm/single_sbatch_runner.py b/src/cloudai/systems/slurm/single_sbatch_runner.py index 6f763f4b8..c98b9458e 100644 --- a/src/cloudai/systems/slurm/single_sbatch_runner.py +++ b/src/cloudai/systems/slurm/single_sbatch_runner.py @@ -134,12 +134,10 @@ def unroll_dse(self, tr: TestRun) -> Generator[TestRun, None, None]: yield next_tr def get_global_env_vars(self) -> str: - vars: list[str] = ["export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)"] tr = self.test_scenario.test_runs[0] cmd_gen = cast(SlurmCommandGenStrategy, self.get_cmd_gen_strategy(self.system, tr)) - for key, value in cmd_gen.final_env_vars.items(): - vars.append(f"export {key}={value}") - return "\n".join(vars) + env_vars = cmd_gen.get_sbatch_env_vars() | cmd_gen.final_env_vars + return "\n".join([f"export {key}={value}" for key, value in env_vars.items()]) def gen_sbatch_content(self) -> str: content: list[str] = ["#!/bin/bash", *self.get_sbatch_directives(), ""] diff --git a/src/cloudai/systems/slurm/slurm_command_gen_strategy.py b/src/cloudai/systems/slurm/slurm_command_gen_strategy.py index 67619a1b9..e15a89f82 100644 --- a/src/cloudai/systems/slurm/slurm_command_gen_strategy.py +++ b/src/cloudai/systems/slurm/slurm_command_gen_strategy.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from abc import abstractmethod from datetime import datetime @@ -61,6 +60,17 @@ def nodelist_in_use(self) -> bool: _, nodes = self.get_cached_nodes_spec() return len(nodes) > 0 + def get_sbatch_env_vars(self) -> dict[str, str]: + env_vars = { + "SLURM_JOB_MASTER_NODE": "$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)", + } + + _, hostfile = self.get_nodes_related_directives() + if hostfile is not None: + env_vars["SLURM_HOSTFILE"] = str(hostfile) + + return env_vars + @abstractmethod def _container_mounts(self) -> list[str]: """Return CommandGenStrategy specific container mounts for the test run.""" @@ -274,22 +284,10 @@ def gen_srun_prefix(self, use_pretest_extras: bool = False, with_num_nodes: bool def generate_test_command(self) -> List[str]: return [] - def _add_reservation(self, batch_script_content: List[str]): - """ - Add reservation if provided. - - Args: - batch_script_content (List[str]): content of the batch script. - - Returns: - List[str]: updated batch script with reservation if exists. - """ - reservation_key = "--reservation " - if self.system.extra_srun_args and reservation_key in self.system.extra_srun_args: - reservation = self.system.extra_srun_args.split(reservation_key, 1)[1].split(" ", 1)[0] - batch_script_content.append(f"#SBATCH --reservation={reservation}") - - return batch_script_content + def _get_reservation(self) -> str | None: + if self.system.extra_srun_args and "--reservation " in self.system.extra_srun_args: + return self.system.extra_srun_args.split("--reservation ", 1)[1].split(" ", 1)[0] + return None def _ranks_mapping_cmd(self) -> str: return " ".join( @@ -352,6 +350,11 @@ def _write_sbatch_script(self, srun_command: str) -> str: ] self._append_sbatch_directives(batch_script_content) + batch_script_content.append("") + batch_script_content.extend([self._format_env_vars(self.get_sbatch_env_vars())]) + + if sbatch_prefix := self._gen_sbatch_prefix(): + batch_script_content.extend(sbatch_prefix) batch_script_content.extend([self._format_env_vars(self.final_env_vars)]) @@ -368,50 +371,65 @@ def _write_sbatch_script(self, srun_command: str) -> str: return f"sbatch {batch_script_path}" - def _append_sbatch_directives(self, batch_script_content: List[str]) -> None: - """ - Append SBATCH directives to the batch script content. + def _get_sbatch_directives(self) -> dict[str, str]: + directives = {} - Args: - batch_script_content (List[str]): The list of script lines to append to. - """ - batch_script_content = self._add_reservation(batch_script_content) + if reservation := self._get_reservation(): + directives["reservation"] = reservation + + directives["output"] = self.test_run.output_path.absolute() / "stdout.txt" + directives["error"] = self.test_run.output_path.absolute() / "stderr.txt" + directives["partition"] = self.system.default_partition - batch_script_content.append(f"#SBATCH --output={self.test_run.output_path.absolute() / 'stdout.txt'}") - batch_script_content.append(f"#SBATCH --error={self.test_run.output_path.absolute() / 'stderr.txt'}") - batch_script_content.append(f"#SBATCH --partition={self.system.default_partition}") if self.system.account: - batch_script_content.append(f"#SBATCH --account={self.system.account}") + directives["account"] = self.system.account - hostfile = self._append_nodes_related_directives(batch_script_content) + if self.system.distribution: + directives["distribution"] = self.system.distribution + + directives.update(self.get_nodes_related_directives()[0]) if self.system.gpus_per_node and self.system.supports_gpu_directives: - batch_script_content.append(f"#SBATCH --gpus-per-node={self.system.gpus_per_node}") - batch_script_content.append(f"#SBATCH --gres=gpu:{self.system.gpus_per_node}") + directives["gpus-per-node"] = self.system.gpus_per_node + directives["gres"] = f"gpu:{self.system.gpus_per_node}" if self.system.ntasks_per_node: - batch_script_content.append(f"#SBATCH --ntasks-per-node={self.system.ntasks_per_node}") + directives["ntasks_per_node"] = self.system.ntasks_per_node + if self.test_run.time_limit: - batch_script_content.append(f"#SBATCH --time={self.test_run.time_limit}") + directives["time"] = self.test_run.time_limit for arg in self.system.extra_sbatch_args: - batch_script_content.append(f"#SBATCH {arg}") + directives[arg] = "" - if hostfile is not None: - batch_script_content.append(f"export SLURM_HOSTFILE={hostfile}") + return directives - batch_script_content.append( - "\nexport SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)" - ) + def _append_sbatch_directives(self, batch_script_content: List[str]) -> None: + """ + Append SBATCH directives to the batch script content. - def _append_nodes_related_directives(self, content: List[str]) -> Optional[Path]: - num_nodes, node_list = self.get_cached_nodes_spec() + Args: + batch_script_content (List[str]): The list of script lines to append to. + """ + directives = self._get_sbatch_directives() + for key, value in directives.items(): + if key.startswith("-"): + # strip makes handling empty `value` cleaner + batch_script_content.append(f"#SBATCH {key} {value}".strip()) + elif value: + batch_script_content.append(f"#SBATCH --{key}={value}") + else: + batch_script_content.append(f"#SBATCH --{key}") + + def _gen_sbatch_prefix(self) -> list[str]: + return [] - if self.system.distribution: - content.append(f"#SBATCH --distribution={self.system.distribution}") + def get_nodes_related_directives(self) -> tuple[dict, Optional[Path]]: + directives = {} + num_nodes, node_list = self.get_cached_nodes_spec() if node_list: - content.append(f"#SBATCH --nodelist={','.join(node_list)}") + directives["nodelist"] = ",".join(node_list) hostfile = (self.test_run.output_path / "hostfile.txt").absolute() with hostfile.open("w") as hf: @@ -420,14 +438,14 @@ def _append_nodes_related_directives(self, content: List[str]) -> Optional[Path] for _ in range(tasks): hf.write(f"{node}\n") - return hostfile + return directives, hostfile - content.append(f"#SBATCH -N {num_nodes}") + directives["-N"] = num_nodes if self.test_run.exclude_nodes: - content.append(f"#SBATCH --exclude={','.join(self.test_run.exclude_nodes)}") + directives["exclude"] = ",".join(self.test_run.exclude_nodes) - return None + return directives, None def _format_env_vars(self, env_vars: Dict[str, Any]) -> str: """ diff --git a/src/cloudai/workloads/common/nixl.py b/src/cloudai/workloads/common/nixl.py index 430a63951..54ffe68aa 100644 --- a/src/cloudai/workloads/common/nixl.py +++ b/src/cloudai/workloads/common/nixl.py @@ -256,7 +256,6 @@ def final_env_vars(self) -> dict[str, str | list[str]]: env_vars = super().final_env_vars env_vars["NIXL_ETCD_NAMESPACE"] = "/nixl/kvbench/$(uuidgen)" env_vars["NIXL_ETCD_ENDPOINTS"] = '"$SLURM_JOB_MASTER_NODE:2379"' - env_vars["SLURM_JOB_MASTER_NODE"] = "$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)" return env_vars @final_env_vars.setter diff --git a/src/cloudai/workloads/deepep/slurm_command_gen_strategy.py b/src/cloudai/workloads/deepep/slurm_command_gen_strategy.py index 60bbf3000..5707ffdd5 100644 --- a/src/cloudai/workloads/deepep/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/deepep/slurm_command_gen_strategy.py @@ -25,37 +25,21 @@ class DeepEPSlurmCommandGenStrategy(SlurmCommandGenStrategy): """Command generation strategy for DeepEP benchmark on Slurm systems.""" - def _append_head_node_detection(self, batch_script_content: List[str]) -> None: - """ - Append bash commands to detect head node IP for torchrun. - - Args: - batch_script_content: The list of script lines to append to. - """ - batch_script_content.extend( - [ - "", - "nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )", - "nodes_array=($nodes)", - "head_node=${nodes_array[0]}", - 'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)', - "", - "echo Nodes: $SLURM_JOB_NODELIST", - "echo Num Nodes: ${#nodes[@]}", - "echo Head Node IP: $head_node_ip", - "", - ] - ) - - def _append_sbatch_directives(self, batch_script_content: List[str]) -> None: - """ - Append SBATCH directives and head node detection setup for DeepEP. - - Args: - batch_script_content: The list of script lines to append to. - """ - super()._append_sbatch_directives(batch_script_content) - self._append_head_node_detection(batch_script_content) + def _gen_sbatch_prefix(self) -> list[str]: + """Append bash commands to detect head node IP for torchrun.""" + return [ + *super()._gen_sbatch_prefix(), + "", + "nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )", + "nodes_array=($nodes)", + "head_node=${nodes_array[0]}", + 'head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)', + "", + "echo Nodes: $SLURM_JOB_NODELIST", + "echo Num Nodes: ${#nodes[@]}", + "echo Head Node IP: $head_node_ip", + "", + ] def _container_mounts(self) -> List[str]: """Return container mounts specific to DeepEP benchmark.""" diff --git a/src/cloudai/workloads/nixl_ep/slurm_command_gen_strategy.py b/src/cloudai/workloads/nixl_ep/slurm_command_gen_strategy.py index bea7e7dd0..510a35fba 100644 --- a/src/cloudai/workloads/nixl_ep/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/nixl_ep/slurm_command_gen_strategy.py @@ -64,23 +64,21 @@ def num_processes_per_node(self) -> int: raise ValueError("NIXL EP Slurm command generation requires num_processes_per_node to be an integer.") return num_processes_per_node - def _append_sbatch_directives(self, batch_script_content: list[str]) -> None: - super()._append_sbatch_directives(batch_script_content) - batch_script_content.extend( - [ - "", - "nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )", - "nodes_array=($nodes)", - "master_node=${nodes_array[0]}", - "master_ip=$(srun --nodes=1 --ntasks=1 -w \"$master_node\" hostname --ip-address | awk '{print $1}')", - "", - "echo Nodes: $SLURM_JOB_NODELIST", - "echo Num Nodes: ${#nodes[@]}", - "echo Master Node: $master_node", - "echo Master IP: $master_ip", - "", - ] - ) + def _gen_sbatch_prefix(self) -> list[str]: + return [ + *super()._gen_sbatch_prefix(), + "", + "nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )", + "nodes_array=($nodes)", + "master_node=${nodes_array[0]}", + "master_ip=$(srun --nodes=1 --ntasks=1 -w \"$master_node\" hostname --ip-address | awk '{print $1}')", + "", + "echo Nodes: $SLURM_JOB_NODELIST", + "echo Num Nodes: ${#nodes[@]}", + "echo Master Node: $master_node", + "echo Master IP: $master_ip", + "", + ] @property def env_vars_path(self) -> Path: diff --git a/src/cloudai/workloads/triton_inference/slurm_command_gen_strategy.py b/src/cloudai/workloads/triton_inference/slurm_command_gen_strategy.py index 0dd4075be..adbff39b1 100644 --- a/src/cloudai/workloads/triton_inference/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/triton_inference/slurm_command_gen_strategy.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ # limitations under the License. from pathlib import Path -from typing import Any, Dict, List, Tuple, cast +from typing import Any, Dict, Tuple, cast from cloudai.core import TestRun from cloudai.systems.slurm import SlurmCommandGenStrategy, SlurmSystem @@ -44,12 +44,14 @@ def _container_mounts(self) -> list[str]: return mounts - def _append_sbatch_directives(self, batch_script_content: List[str]) -> None: - super()._append_sbatch_directives(batch_script_content) - batch_script_content.append("export HEAD_NODE=$SLURM_JOB_MASTER_NODE") - batch_script_content.append("export NIM_LEADER_IP_ADDRESS=$SLURM_JOB_MASTER_NODE") - batch_script_content.append(f"export NIM_NUM_COMPUTE_NODES={self.test_run.nnodes - 1}") - batch_script_content.append("export NIM_MODEL_TOKENIZER='deepseek-ai/DeepSeek-R1'") + def _gen_sbatch_prefix(self) -> list[str]: + return [ + *super()._gen_sbatch_prefix(), + "export HEAD_NODE=$SLURM_JOB_MASTER_NODE", + "export NIM_LEADER_IP_ADDRESS=$SLURM_JOB_MASTER_NODE", + f"export NIM_NUM_COMPUTE_NODES={self.test_run.nnodes - 1}", + "export NIM_MODEL_TOKENIZER='deepseek-ai/DeepSeek-R1'", + ] def _generate_start_wrapper_script(self, script_path: Path, env_vars: Dict[str, Any]) -> None: lines = ["#!/bin/bash", ""] diff --git a/tests/ref_data/nixl-kvbench.sbatch b/tests/ref_data/nixl-kvbench.sbatch index 817b0eacb..6e422791c 100644 --- a/tests/ref_data/nixl-kvbench.sbatch +++ b/tests/ref_data/nixl-kvbench.sbatch @@ -11,7 +11,6 @@ export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) export NIXL_ETCD_NAMESPACE=/nixl/kvbench/$(uuidgen) export NIXL_ETCD_ENDPOINTS="$SLURM_JOB_MASTER_NODE:2379" -export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) srun --export=ALL --mpi=pmix -N2 --container-image=url.com/docker:tag --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__OUTPUT_DIR__/install:/cloudai_install,__OUTPUT_DIR__/output --output=__OUTPUT_DIR__/output/mapping-stdout.txt --error=__OUTPUT_DIR__/output/mapping-stderr.txt bash -c "echo \$(date): \$(hostname):node \${SLURM_NODEID}:rank \${SLURM_PROCID}." srun --export=ALL --mpi=pmix -N2 --container-image=url.com/docker:tag --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__OUTPUT_DIR__/install:/cloudai_install,__OUTPUT_DIR__/output --ntasks=2 --ntasks-per-node=1 --output=__OUTPUT_DIR__/output/metadata/node-%N.toml --error=__OUTPUT_DIR__/output/metadata/nodes.err bash /cloudai_install/slurm-metadata.sh diff --git a/tests/ref_data/nixl-perftest.sbatch b/tests/ref_data/nixl-perftest.sbatch index 6b4d2ba88..68351a33c 100644 --- a/tests/ref_data/nixl-perftest.sbatch +++ b/tests/ref_data/nixl-perftest.sbatch @@ -11,7 +11,6 @@ export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) export NIXL_ETCD_NAMESPACE=/nixl/kvbench/$(uuidgen) export NIXL_ETCD_ENDPOINTS="$SLURM_JOB_MASTER_NODE:2379" -export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) srun --export=ALL --mpi=pmix -N1 --container-image=url.com/docker:tag --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__OUTPUT_DIR__/install:/cloudai_install,__OUTPUT_DIR__/output --output=__OUTPUT_DIR__/output/mapping-stdout.txt --error=__OUTPUT_DIR__/output/mapping-stderr.txt bash -c "echo \$(date): \$(hostname):node \${SLURM_NODEID}:rank \${SLURM_PROCID}." srun --export=ALL --mpi=pmix -N1 --container-image=url.com/docker:tag --container-mounts=__OUTPUT_DIR__/output:/cloudai_run_results,__OUTPUT_DIR__/install:/cloudai_install,__OUTPUT_DIR__/output --ntasks=1 --ntasks-per-node=1 --output=__OUTPUT_DIR__/output/metadata/node-%N.toml --error=__OUTPUT_DIR__/output/metadata/nodes.err bash /cloudai_install/slurm-metadata.sh diff --git a/tests/ref_data/nixl_bench.sbatch b/tests/ref_data/nixl_bench.sbatch index b3191cc0d..be8b4d0ec 100644 --- a/tests/ref_data/nixl_bench.sbatch +++ b/tests/ref_data/nixl_bench.sbatch @@ -11,7 +11,6 @@ export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) export NIXL_ETCD_NAMESPACE=/nixl/kvbench/$(uuidgen) export NIXL_ETCD_ENDPOINTS="$SLURM_JOB_MASTER_NODE:2379" -export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1) srun --export=ALL --mpi=pmix -N2 --output=__OUTPUT_DIR__/output/mapping-stdout.txt --error=__OUTPUT_DIR__/output/mapping-stderr.txt bash -c "echo \$(date): \$(hostname):node \${SLURM_NODEID}:rank \${SLURM_PROCID}." srun --export=ALL --mpi=pmix -N2 --ntasks=2 --ntasks-per-node=1 --output=__OUTPUT_DIR__/output/metadata/node-%N.toml --error=__OUTPUT_DIR__/output/metadata/nodes.err bash __INSTALL_DIR__/slurm-metadata.sh diff --git a/tests/systems/slurm/test_command_gen_strategy.py b/tests/systems/slurm/test_command_gen_strategy.py index 98e103704..e113ff5d3 100644 --- a/tests/systems/slurm/test_command_gen_strategy.py +++ b/tests/systems/slurm/test_command_gen_strategy.py @@ -340,7 +340,7 @@ def test_append_distribution_and_hostfile_with_nodes(slurm_system: SlurmSystem, testrun_fixture.nodes = ["node1", "node2"] strategy = MySlurmCommandGenStrategy(slurm_system, testrun_fixture) content: List[str] = [] - strategy._append_nodes_related_directives(content) + strategy._append_sbatch_directives(content) assert "#SBATCH --distribution=block" in content assert "#SBATCH --nodelist=node1,node2" in content @@ -355,7 +355,7 @@ def test_distribution_fallback_when_no_nodes(strategy_fixture: SlurmCommandGenSt strategy_fixture.test_run.nodes = [] strategy_fixture.system.distribution = "cyclic" content: List[str] = [] - strategy_fixture._append_nodes_related_directives(content) + strategy_fixture._append_sbatch_directives(content) assert "#SBATCH --distribution=cyclic" in content assert "#SBATCH --nodelist=" not in content @@ -366,7 +366,7 @@ def test_exclude_nodes_directive_when_no_nodelist(strategy_fixture: SlurmCommand strategy_fixture.test_run.num_nodes = 3 strategy_fixture.test_run.exclude_nodes = ["node01", "node02"] content: List[str] = [] - strategy_fixture._append_nodes_related_directives(content) + strategy_fixture._append_sbatch_directives(content) assert "#SBATCH -N 3" in content assert "#SBATCH --exclude=node01,node02" in content @@ -377,7 +377,7 @@ def test_no_exclude_directive_when_nodelist_present(slurm_system: SlurmSystem, t testrun_fixture.exclude_nodes = ["node01", "node02"] strategy = MySlurmCommandGenStrategy(slurm_system, testrun_fixture) content: List[str] = [] - strategy._append_nodes_related_directives(content) + strategy._append_sbatch_directives(content) assert "#SBATCH --nodelist=node3,node4" in content assert "#SBATCH --exclude=" not in content @@ -388,7 +388,7 @@ def test_no_exclude_directive_when_exclude_nodes_unset(strategy_fixture: SlurmCo strategy_fixture.test_run.num_nodes = 2 strategy_fixture.test_run.exclude_nodes = [] content: List[str] = [] - strategy_fixture._append_nodes_related_directives(content) + strategy_fixture._append_sbatch_directives(content) assert "#SBATCH -N 2" in content assert not any("--exclude" in line for line in content) diff --git a/tests/workloads/triton_inference/test_command_gen_strategy_slurm.py b/tests/workloads/triton_inference/test_command_gen_strategy_slurm.py index 8e17208fe..d4f137e5c 100644 --- a/tests/workloads/triton_inference/test_command_gen_strategy_slurm.py +++ b/tests/workloads/triton_inference/test_command_gen_strategy_slurm.py @@ -16,7 +16,7 @@ import stat from pathlib import Path -from typing import List, cast +from typing import cast from unittest.mock import Mock import pytest @@ -113,9 +113,8 @@ def test_generate_start_wrapper_script(tmp_path: Path, strategy: TritonInference assert bool(mode & stat.S_IXUSR) -def test_append_sbatch_directives(strategy: TritonInferenceSlurmCommandGenStrategy) -> None: - lines: List[str] = [] - strategy._append_sbatch_directives(lines) +def test_gen_sbatch_prefix(strategy: TritonInferenceSlurmCommandGenStrategy) -> None: + lines = strategy._gen_sbatch_prefix() assert "export HEAD_NODE=$SLURM_JOB_MASTER_NODE" in lines assert "export NIM_LEADER_IP_ADDRESS=$SLURM_JOB_MASTER_NODE" in lines assert "export NIM_NUM_COMPUTE_NODES=2" in lines From 677f7d63c8df153fb0c194242aea119d9a8d9da0 Mon Sep 17 00:00:00 2001 From: Ivan Podkidyshev Date: Wed, 15 Apr 2026 21:41:54 +0200 Subject: [PATCH 2/3] properly handle basic sruncommandgen methods --- .../slurm/slurm_command_gen_strategy.py | 10 +- .../slurm_command_gen_strategy.py | 163 +++++++------- tests/ref_data/megatron-bridge.sbatch | 3 +- .../test_command_gen_strategy_slurm.py | 202 ++++-------------- 4 files changed, 124 insertions(+), 254 deletions(-) diff --git a/src/cloudai/systems/slurm/slurm_command_gen_strategy.py b/src/cloudai/systems/slurm/slurm_command_gen_strategy.py index e15a89f82..8df130873 100644 --- a/src/cloudai/systems/slurm/slurm_command_gen_strategy.py +++ b/src/cloudai/systems/slurm/slurm_command_gen_strategy.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging from abc import abstractmethod from datetime import datetime @@ -71,6 +72,11 @@ def get_sbatch_env_vars(self) -> dict[str, str]: return env_vars + def write_env_vars(self): + with (self.test_run.output_path / "env_vars.sh").open("w") as f: + for key, value in self.final_env_vars.items(): + f.write(f'export {key}="{value}"\n') + @abstractmethod def _container_mounts(self) -> list[str]: """Return CommandGenStrategy specific container mounts for the test run.""" @@ -241,9 +247,7 @@ def _gen_srun_command(self) -> str: nsys_command_parts = self.gen_nsys_command() test_command_parts = self.generate_test_command() - with (self.test_run.output_path / "env_vars.sh").open("w") as f: - for key, value in self.final_env_vars.items(): - f.write(f'export {key}="{value}"\n') + self.write_env_vars() full_test_cmd = ( f'bash -c "source {(self.test_run.output_path / "env_vars.sh").absolute()}; ' diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index f33c71964..76c39f6fd 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -37,14 +37,6 @@ class MegatronBridgeSlurmCommandGenStrategy(SlurmCommandGenStrategy): The launcher submits the actual training sbatch job; CloudAI tracks that job ID via SlurmRunner parsing. """ - CONTAINER_RUNTIME_ENV_VARS: frozenset[str] = frozenset( - { - "MELLANOX_VISIBLE_DEVICES", - "NVIDIA_VISIBLE_DEVICES", - "NVIDIA_DRIVER_CAPABILITIES", - } - ) - def _container_mounts(self) -> list[str]: # This workload submits its own sbatch job and passes mounts via `-cm`. return [] @@ -92,36 +84,70 @@ def _write_command_to_file(self, command: str, output_path: Path) -> None: with log_file.open("w") as f: f.write(f"{command}\n") - def _build_custom_bash_env_exports(self) -> list[str]: - """ - Build repeated -cb entries that export env vars inside the launched Slurm job shell. + def _get_startup_env_vars(self) -> dict[str, str]: + """Env vars to export before sbatch submission (those that contain commas) = best effort to support them.""" + return {k: str(v) for k, v in (self.get_sbatch_env_vars() | self.final_env_vars).items() if "," in v} - We quote each full `export KEY=value` command so `$SLURM_*` and commas survive - argument parsing on the submit node and are expanded/interpreted in the job shell. - """ - exports: list[str] = [] - for key, value in sorted(self.final_env_vars.items()): - exports.extend(["-cb", shlex.quote(f"export {key}={value}")]) - return exports + def _build_custom_env_vars(self) -> str: + ignored_env_vars = {} - def _container_runtime_env_exports(self) -> list[str]: - """ - Build ``export`` lines for container-runtime env vars. - - Variables like ``MELLANOX_VISIBLE_DEVICES`` and ``NVIDIA_VISIBLE_DEVICES`` - are consumed by the NVIDIA container toolkit / enroot at container-creation - time to decide which devices to mount. They must be present in the process - environment **before** the Megatron-Bridge launcher calls ``sbatch`` so that - Slurm inherits them into the job and ``srun`` passes them to the container - runtime. Exporting them in the wrapper script (which runs on the submit - node) achieves this. The same variables are still passed via ``-cb`` as - well, so they are also set inside the container for any runtime readers. - """ - lines: list[str] = [] - for key, value in sorted(self.final_env_vars.items()): - if key in self.CONTAINER_RUNTIME_ENV_VARS: - lines.append(f"export {key}={shlex.quote(str(value))}") - return lines + custom_env_vars = {} + for k, v in (self.get_sbatch_env_vars() | self.final_env_vars).items(): + v = str(v) + + if "," in v: + ignored_env_vars[k] = v + continue + + # make sure expressions aren't evaluated before sbatch submit + if "$(" in v: + v = v.replace("$(", "\\$(") + + custom_env_vars[k] = v + + if ignored_env_vars: + logging.warning( + "Megatron-Bridge has limited support for env vars with commas. The following env vars will not be " + "evaluated inside SBATCH. Only outside (before sbatch submission) and on each compute node (srun).\n%s", + ",".join(f"{key}={value}" for key, value in ignored_env_vars.items()), + ) + + if not custom_env_vars: + return "" + + return shlex.quote(",".join(f"{key}={value}" for key, value in custom_env_vars.items())) + + def _build_custom_srun_args(self) -> list[str]: + srun_parts = self.gen_srun_prefix() + srun_parts.remove("srun") + + part_prefixes_to_remove = ( + "--mpi", + "--container-writable", + "--no-container-mount-home", + "--container-mounts", + "--container-image", + ) + return [part for part in srun_parts if not part.startswith(part_prefixes_to_remove)] + + def _build_additional_slurm_params(self) -> str: + directives = self._get_sbatch_directives() + additional_slurm_params: list[str] = [] + + for key, value in directives.items(): + if key == "-N": + key = "nodes" + key = key.lstrip("-").replace("_", "-") + + if value is True or value == "": + additional_slurm_params.append(key) + else: + additional_slurm_params.append(f"{key}={value}") + + if not additional_slurm_params: + return "" + + return shlex.quote(";".join(additional_slurm_params)) def _normalize_recompute_modules(self, val: Any) -> str: if isinstance(val, list): @@ -135,30 +161,6 @@ def _normalize_recompute_modules(self, val: Any) -> str: joined = ",".join(items) return f'"{joined}"' - @staticmethod - def _parse_srun_args_as_slurm_params(srun_args: str) -> list[str]: - """ - Convert ``--key value`` pairs from extra_srun_args into ``key=value`` for --additional_slurm_params. - - Standalone boolean flags (e.g. ``--exclusive``) are emitted as bare - key names without a ``=value`` suffix. - """ - params: list[str] = [] - tokens = shlex.split(srun_args) - i = 0 - while i < len(tokens): - tok = tokens[i] - if tok.startswith("--") and "=" in tok: - key, val = tok[2:].split("=", 1) - params.append(f"{key}={val}") - elif tok.startswith("--") and i + 1 < len(tokens) and not tokens[i + 1].startswith("--"): - params.append(f"{tok[2:]}={tokens[i + 1]}") - i += 1 - elif tok.startswith("--"): - params.append(tok[2:]) - i += 1 - return params - def _normalize_cuda_graph_scope_arg(self, val: Any) -> str: s = str(val).strip().strip("\"'") if s.startswith("[") and s.endswith("]"): @@ -179,8 +181,6 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher wrapper_path = output_dir / "cloudai_megatron_bridge_submit_and_parse_jobid.sh" log_path = output_dir / "cloudai_megatron_bridge_launcher.log" - container_runtime_exports = self._container_runtime_env_exports() - script_lines = [ "#!/usr/bin/env bash", "set -o pipefail", @@ -193,7 +193,7 @@ def _wrap_launcher_for_job_id_and_quiet_output(self, launcher_cmd: str, launcher # Mirror wrapper stdout/stderr to files for debugging while still emitting to the parent process. 'exec > >(tee -a "$WRAPPER_STDOUT") 2> >(tee -a "$WRAPPER_STDERR" >&2)', "", - *container_runtime_exports, + self._format_env_vars(self._get_startup_env_vars()), "", ': >"$LOG"', "WANDB_INSTALL_RC=0", @@ -289,11 +289,7 @@ def _installed_container_path() -> str: venv_path = tdef.python_executable.venv_path or (self.system.install_path / tdef.python_executable.venv_name) python_bin = (venv_path / "bin" / "python").absolute() - parts: list[str] = [ - f'NEMORUN_HOME="{self.test_run.output_path.absolute()}"', - str(python_bin), - str(Path(launcher_py).absolute()), - ] + parts: list[str] = [str(python_bin), str(Path(launcher_py).absolute())] def add(flag: str, value: Any) -> None: if value is None: @@ -345,10 +341,10 @@ def add_field(field: str, flag: str, value: Any) -> None: if mounts: add("-cm", ",".join(mounts)) - # Pass extra env variables as `-cb export KEY=value` commands to avoid Megatron-Bridge's - # --custom_env_vars parser limitation for comma-containing values. - if self.final_env_vars: - parts.extend(self._build_custom_bash_env_exports()) + # Manage environment variables + mounts.append(f"{(self.test_run.output_path / 'env_vars.sh').absolute()}:/env_vars.sh") + parts.extend(["-cb", "'source /env_vars.sh'"]) + self.write_env_vars() # Model flags (Megatron-Bridge main-branch API) add_field("domain", "--domain", args.domain) @@ -451,25 +447,14 @@ def add_field(field: str, flag: str, value: Any) -> None: add_field("nsys_trace", "--nsys_trace", self._list_or_comma_str(args.nsys_trace)) add_field("nsys_extra_args", "--nsys_extra_args", self._list_or_comma_str(args.nsys_extra_args)) - additional_slurm_params: list[str] = [] - - if self.system.gpus_per_node and self.system.supports_gpu_directives: - additional_slurm_params.append(f"gpus-per-node={self.system.gpus_per_node}") - additional_slurm_params.append(f"gres=gpu:{self.system.gpus_per_node}") - - _, node_list = self.get_cached_nodes_spec() - if node_list: - nodelist_str = ",".join(node_list) - additional_slurm_params.append(f"nodelist={nodelist_str}") - elif self.test_run.exclude_nodes: - additional_slurm_params.append(f"exclude={','.join(self.test_run.exclude_nodes)}") + if additional_slurm_params := self._build_additional_slurm_params(): + parts.extend(["--additional_slurm_params", additional_slurm_params]) - for source in (self.system.extra_srun_args, self.test_run.extra_srun_args): - if source: - additional_slurm_params.extend(self._parse_srun_args_as_slurm_params(source)) + if custom_srun_args := self._build_custom_srun_args(): + parts.extend(["--custom_srun_args", f"'{','.join(custom_srun_args)}'"]) - if additional_slurm_params: - parts.extend(["--additional_slurm_params", shlex.quote(";".join(additional_slurm_params))]) + if custom_env_vars := self._build_custom_env_vars(): + parts.extend(["--custom_env_vars", custom_env_vars]) # Config variant add_field("config_variant", "-cv", args.config_variant) diff --git a/tests/ref_data/megatron-bridge.sbatch b/tests/ref_data/megatron-bridge.sbatch index 3d0dc840a..b26f2cc33 100644 --- a/tests/ref_data/megatron-bridge.sbatch +++ b/tests/ref_data/megatron-bridge.sbatch @@ -8,6 +8,7 @@ WRAPPER_STDOUT="__OUTPUT_DIR__/output/cloudai_megatron_bridge_wrapper.stdout" WRAPPER_STDERR="__OUTPUT_DIR__/output/cloudai_megatron_bridge_wrapper.stderr" exec > >(tee -a "$WRAPPER_STDOUT") 2> >(tee -a "$WRAPPER_STDERR" >&2) +export CUDA_VISIBLE_DEVICES=0,1,2,3 : >"$LOG" WANDB_INSTALL_RC=0 @@ -19,7 +20,7 @@ if [ "${WANDB_INSTALL_RC}" -ne 0 ]; then fi LAUNCH_RC=0 -NEMORUN_HOME="__OUTPUT_DIR__/output" __INSTALL_DIR__/Run__main-venv/bin/python __INSTALL_DIR__/Megatron-Bridge__main/scripts/performance/setup_experiment.py -p main -t 00:20:00 -i __OUTPUT_DIR__/output/megatron_bridge_image.sqsh -hf dummy_token -ng 8 -gn 8 -cm __INSTALL_DIR__/Megatron-Bridge__main:/opt/Megatron-Bridge -cb 'export CUDA_VISIBLE_DEVICES=0,1,2,3' -cb 'export NCCL_DEBUG=INFO' -m qwen3 -mr 30b_a3b --detach false --additional_slurm_params 'gpus-per-node=8;gres=gpu:8' >>"$LOG" 2>&1 || LAUNCH_RC=$? +__INSTALL_DIR__/Run__main-venv/bin/python __INSTALL_DIR__/Megatron-Bridge__main/scripts/performance/setup_experiment.py -p main -t 00:20:00 -i __OUTPUT_DIR__/output/megatron_bridge_image.sqsh -hf dummy_token -ng 8 -gn 8 -cm __INSTALL_DIR__/Megatron-Bridge__main:/opt/Megatron-Bridge -cb 'source /env_vars.sh' -m qwen3 -mr 30b_a3b --detach false --additional_slurm_params 'output=__OUTPUT_DIR__/output/stdout.txt;error=__OUTPUT_DIR__/output/stderr.txt;partition=main;nodes=1;gpus-per-node=8;gres=gpu:8;time=00:20:00' --custom_srun_args '--export=ALL,-N1' --custom_env_vars 'SLURM_JOB_MASTER_NODE=\$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1),NCCL_DEBUG=INFO' >>"$LOG" 2>&1 || LAUNCH_RC=$? JOB_ID="" diff --git a/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py b/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py index ebad2bb82..2e00094ae 100644 --- a/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py +++ b/tests/workloads/megatron_bridge/test_command_gen_strategy_slurm.py @@ -49,6 +49,13 @@ def _wrapper_content(cmd_gen: MegatronBridgeSlurmCommandGenStrategy) -> str: assert wrapper.exists() return wrapper.read_text() + @staticmethod + def _env_file_content(cmd_gen: MegatronBridgeSlurmCommandGenStrategy) -> str: + cmd_gen.gen_exec_command() + env_file = cmd_gen.test_run.output_path / "env_vars.sh" + assert env_file.exists() + return env_file.read_text() + @pytest.fixture def configured_slurm_system(self, slurm_system: SlurmSystem) -> SlurmSystem: slurm_system.account = "acct" @@ -95,6 +102,7 @@ def _make( git_repos=[GitRepo(**repo_kwargs)], ) self._configure_fake_installs(tdef, tmp_path) + (tmp_path / output_subdir).mkdir() return TestRun( test=tdef, name="tr", @@ -169,11 +177,6 @@ def test_model_fields_validation(self, field_name: str, value: str, match: str) with pytest.raises(Exception, match=match): MegatronBridgeCmdArgs.model_validate(data) - def test_git_repos_can_pin_megatron_bridge_commit(self, make_test_run: Callable[..., TestRun]) -> None: - tr = make_test_run(git_commit="abcdef1234567890") - tdef = cast(MegatronBridgeTestDefinition, tr.test) - assert tdef.megatron_bridge_repo.commit == "abcdef1234567890" - def test_defaults_not_emitted_when_not_set_in_toml( self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] ) -> None: @@ -204,88 +207,6 @@ def test_cuda_graph_scope_normalization( wrapper_content = self._wrapper_content(cmd_gen) assert "--cuda_graph_scope moe_router,moe_preprocess" in wrapper_content - def test_env_vars_are_forwarded_via_custom_bash_cmds( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - tr = make_test_run() - tdef = cast(MegatronBridgeTestDefinition, tr.test) - tdef.extra_env_vars = {"CUDA_VISIBLE_DEVICES": "0,1,2,3", "NCCL_DEBUG": "INFO"} - - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - assert "--custom_env_vars" not in wrapper_content - assert "-cb 'export CUDA_VISIBLE_DEVICES=0,1,2,3'" in wrapper_content - assert "-cb 'export NCCL_DEBUG=INFO'" in wrapper_content - - def test_container_runtime_env_vars_exported_in_wrapper_script( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - configured_slurm_system.global_env_vars = { - "MELLANOX_VISIBLE_DEVICES": "0,1,4,5", - "NCCL_IB_HCA": "roce_p0_r0,roce_p0_r1,roce_p0_r2,roce_p0_r3", - "NCCL_IB_GID_INDEX": "3", - } - tr = make_test_run(output_subdir="out_container_rt") - tdef = cast(MegatronBridgeTestDefinition, tr.test) - tdef.extra_env_vars = {"NVIDIA_VISIBLE_DEVICES": "all", "NCCL_DEBUG": "INFO"} - - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - - launcher_idx = wrapper_content.index("setup_experiment.py") - - assert "export MELLANOX_VISIBLE_DEVICES=0,1,4,5" in wrapper_content - assert "export NVIDIA_VISIBLE_DEVICES=all" in wrapper_content - mvd_idx = wrapper_content.index("export MELLANOX_VISIBLE_DEVICES=") - nvd_idx = wrapper_content.index("export NVIDIA_VISIBLE_DEVICES=") - assert mvd_idx < launcher_idx, "MELLANOX_VISIBLE_DEVICES must be exported before the launcher" - assert nvd_idx < launcher_idx, "NVIDIA_VISIBLE_DEVICES must be exported before the launcher" - - assert "-cb 'export MELLANOX_VISIBLE_DEVICES=0,1,4,5'" in wrapper_content - assert "-cb 'export NVIDIA_VISIBLE_DEVICES=all'" in wrapper_content - assert "-cb 'export NCCL_IB_HCA=roce_p0_r0,roce_p0_r1,roce_p0_r2,roce_p0_r3'" in wrapper_content - assert "-cb 'export NCCL_DEBUG=INFO'" in wrapper_content - - assert "export NCCL_IB_HCA=" not in wrapper_content.split("setup_experiment.py")[0] - assert "export NCCL_DEBUG=" not in wrapper_content.split("setup_experiment.py")[0] - - def test_wrapper_emits_job_id_even_when_launcher_non_zero( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - tr = make_test_run() - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - assert 'if [ "${LAUNCH_RC}" -ne 0 ]; then' in wrapper_content - assert 'echo "Submitted batch job ${JOB_ID}"' in wrapper_content - assert 'exit "${LAUNCH_RC}"' not in wrapper_content - assert "Submitted batch job[ ]+[0-9]+" in wrapper_content - - def test_wrapper_installs_wandb_before_launcher( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - tr = make_test_run() - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - - assert "-m pip install wandb numpy==1.26.4" in wrapper_content - wandb_idx = wrapper_content.index("-m pip install wandb") - launcher_idx = wrapper_content.index("setup_experiment.py") - assert wandb_idx < launcher_idx - - def test_wrapper_exits_when_wandb_install_fails( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - tr = make_test_run() - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - - assert 'if [ "${WANDB_INSTALL_RC}" -ne 0 ]; then' in wrapper_content - assert ( - 'echo "Failed to install runtime deps (wandb, numpy==1.26.4) in launcher venv (exit ' - '${WANDB_INSTALL_RC})." >&2' - ) in wrapper_content - assert 'exit "${WANDB_INSTALL_RC}"' in wrapper_content - @pytest.mark.parametrize( ("log_content", "expected_is_successful"), ( @@ -339,93 +260,37 @@ def test_use_recipes_emitted_only_when_true( wrapper_content = self._wrapper_content(cmd_gen) assert ("--use_recipes" in wrapper_content) is expected_in_wrapper + @pytest.mark.parametrize("mount_as", ("/opt/custom-megatron", None)) def test_mount_as_adds_repo_to_container_mounts( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun], tmp_path: Path - ) -> None: - tr = make_test_run(mount_as="/opt/custom-megatron", output_subdir="out_mount") - tdef = cast(MegatronBridgeTestDefinition, tr.test) - repo_path = tdef.megatron_bridge_repo.installed_path - assert repo_path is not None - - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - assert f"-cm {repo_path.absolute()}:/opt/custom-megatron" in wrapper_content - - def test_no_mount_as_skips_repo_container_mount( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] + self, + configured_slurm_system: SlurmSystem, + make_test_run: Callable[..., TestRun], + tmp_path: Path, + mount_as: str | None, ) -> None: - tr = make_test_run(mount_as=None, output_subdir="out_no_mount") + tr = make_test_run(mount_as=mount_as, output_subdir="out_mount") tdef = cast(MegatronBridgeTestDefinition, tr.test) repo_path = tdef.megatron_bridge_repo.installed_path assert repo_path is not None cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) wrapper_content = self._wrapper_content(cmd_gen) - assert f"{repo_path.absolute()}:" not in wrapper_content - assert ":/opt/Megatron-Bridge" not in wrapper_content - - @pytest.mark.parametrize(("system_gpus_per_node", "expected_gpus"), ((None, None), (4, 4))) - def test_gpus_per_node( - self, - configured_slurm_system: SlurmSystem, - make_test_run: Callable[..., TestRun], - system_gpus_per_node: int | None, - expected_gpus: int | None, - ) -> None: - configured_slurm_system.supports_gpu_directives_cache = True - configured_slurm_system.gpus_per_node = system_gpus_per_node - tr = make_test_run(output_subdir="out_gpus") - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - - if expected_gpus is None: - assert "--additional_slurm_params" not in wrapper_content - assert "-gn" not in wrapper_content + if mount_as is not None: + assert f"-cm {repo_path.absolute()}:/opt/custom-megatron" in wrapper_content else: - assert "--additional_slurm_params" in wrapper_content - assert f"gpus-per-node={expected_gpus}" in wrapper_content - assert f"gres=gpu:{expected_gpus}" in wrapper_content - assert f"-gn {expected_gpus}" in wrapper_content - - def test_gpus_per_node_skipped_when_gpu_directives_unsupported( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - configured_slurm_system.supports_gpu_directives_cache = False - tr = make_test_run(cmd_args_overrides={"gpus_per_node": 2}, output_subdir="out_no_gpu_directives") - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - assert "gpus-per-node=2" not in wrapper_content - assert "gres=gpu:2" not in wrapper_content + assert f"{repo_path.absolute()}:" not in wrapper_content + assert ":/opt/Megatron-Bridge" not in wrapper_content - def test_system_extra_srun_args_forwarded( + def test_extra_srun_args_forwarded( self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] ) -> None: - configured_slurm_system.extra_srun_args = "--reservation my_reserv" + configured_slurm_system.extra_srun_args = "--reservation=my_reservation" tr = make_test_run(output_subdir="out_srun") + tr.extra_srun_args = "--constraint=gpu" cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) wrapper_content = self._wrapper_content(cmd_gen) - assert "reservation=my_reserv" in wrapper_content - - def test_test_run_extra_srun_args_forwarded( - self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun] - ) -> None: - tr = make_test_run(output_subdir="out_tr_srun") - tr.extra_srun_args = "--constraint gpu" - cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) - wrapper_content = self._wrapper_content(cmd_gen) - assert "constraint=gpu" in wrapper_content - - def test_parse_srun_args_as_slurm_params(self) -> None: - result = MegatronBridgeSlurmCommandGenStrategy._parse_srun_args_as_slurm_params( - "--reservation my_reserv --constraint=gpu" - ) - assert result == ["reservation=my_reserv", "constraint=gpu"] - - def test_parse_srun_args_boolean_flags(self) -> None: - result = MegatronBridgeSlurmCommandGenStrategy._parse_srun_args_as_slurm_params( - "--exclusive --reservation my_reserv --overcommit" - ) - assert result == ["exclusive", "reservation=my_reserv", "overcommit"] + assert "--reservation=my_reservation" in wrapper_content + assert "--constraint=gpu" in wrapper_content def test_profiling_ranks_string_format( self, @@ -441,12 +306,27 @@ def test_profiling_ranks_string_format( wrapper_content = self._wrapper_content(cmd_gen) assert "--profiling_ranks 0,1,2,3" in wrapper_content + @pytest.mark.parametrize(("input_vp", "expected_vp"), ((1, "None"), (2, "2"))) def test_vp( self, configured_slurm_system: SlurmSystem, make_test_run: Callable[..., TestRun], + input_vp: int, + expected_vp: str, ): - tr = make_test_run(cmd_args_overrides={"vp": 1}) + tr = make_test_run(cmd_args_overrides={"vp": input_vp}) cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, tr) wrapper_content = self._wrapper_content(cmd_gen) - assert "-vp None" in wrapper_content + assert expected_vp in wrapper_content + + def test_env_vars(self, test_run: TestRun, configured_slurm_system: SlurmSystem): + configured_slurm_system.global_env_vars = {"MELLANOX_VISIBLE_DEVICES": "0,1,2,3", "BLA": "bla"} + cmd_gen = MegatronBridgeSlurmCommandGenStrategy(configured_slurm_system, test_run) + cmd_gen.gen_exec_command() + + env_file = test_run.output_path / "env_vars.sh" + assert env_file.is_file() + assert env_file.read_text().splitlines() == [ + 'export MELLANOX_VISIBLE_DEVICES="0,1,2,3"', + 'export BLA="bla"', + ] From 1dd08c91d5a84ca0f874736e5273fa52b19440d7 Mon Sep 17 00:00:00 2001 From: Ivan Podkidyshev Date: Mon, 20 Apr 2026 11:55:15 +0200 Subject: [PATCH 3/3] small fixes --- .../megatron_bridge/slurm_command_gen_strategy.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py index 76c39f6fd..5d8eeccc2 100644 --- a/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py +++ b/src/cloudai/workloads/megatron_bridge/slurm_command_gen_strategy.py @@ -99,10 +99,6 @@ def _build_custom_env_vars(self) -> str: ignored_env_vars[k] = v continue - # make sure expressions aren't evaluated before sbatch submit - if "$(" in v: - v = v.replace("$(", "\\$(") - custom_env_vars[k] = v if ignored_env_vars: @@ -338,14 +334,15 @@ def add_field(field: str, flag: str, value: Any) -> None: parts.append("-d") add_field("num_gpus", "-ng", args.num_gpus) add_field("gpus_per_node", "-gn", self.system.gpus_per_node) - if mounts: - add("-cm", ",".join(mounts)) # Manage environment variables mounts.append(f"{(self.test_run.output_path / 'env_vars.sh').absolute()}:/env_vars.sh") parts.extend(["-cb", "'source /env_vars.sh'"]) self.write_env_vars() + if mounts: + add("-cm", ",".join(mounts)) + # Model flags (Megatron-Bridge main-branch API) add_field("domain", "--domain", args.domain) if args.use_recipes and "use_recipes" in fields_set: @@ -451,7 +448,7 @@ def add_field(field: str, flag: str, value: Any) -> None: parts.extend(["--additional_slurm_params", additional_slurm_params]) if custom_srun_args := self._build_custom_srun_args(): - parts.extend(["--custom_srun_args", f"'{','.join(custom_srun_args)}'"]) + parts.extend([f"--custom_srun_args='{','.join(custom_srun_args)}'"]) if custom_env_vars := self._build_custom_env_vars(): parts.extend(["--custom_env_vars", custom_env_vars])