From 47bf72ae805c05ea9c4587979caf362d01e0327d Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Wed, 24 Jun 2026 14:05:08 +0530 Subject: [PATCH 1/5] [Feature] Speed up tests by moving external library imports to inside test functions Replace module-level try/except ImportError patterns with lightweight importlib.util.find_spec() checks, and move actual library imports (gym, gymnasium, ray, tensorboard, hydra, functorch, etc.) into the test functions that need them. This reduces import time when multiprocessing test workers spawn new Python processes. --- test/libs/conftest.py | 29 ++++++++-------- test/libs/test_datasets.py | 20 +++++------ test/libs/test_gym.py | 40 +++++++++++----------- test/modules/_modules_common.py | 13 ++------ test/modules/test_rnn.py | 48 ++++++++++++--------------- test/objectives/_objectives_common.py | 29 ++++++---------- test/objectives/test_ppo.py | 4 +-- test/test_helpers.py | 11 +++--- test/test_inference_server.py | 14 ++------ test/test_trainer.py | 35 ++++++++++--------- 10 files changed, 106 insertions(+), 137 deletions(-) diff --git a/test/libs/conftest.py b/test/libs/conftest.py index 79baad212d6..78088fb7bd2 100644 --- a/test/libs/conftest.py +++ b/test/libs/conftest.py @@ -51,20 +51,7 @@ _has_gymnasium = importlib.util.find_spec("gymnasium") is not None -_has_isaaclab = importlib.util.find_spec("isaaclab") is not None - _has_gym_regular = importlib.util.find_spec("gym") is not None -if _has_gymnasium: - set_gym_backend("gymnasium").set() - import gymnasium - - assert gym_backend() is gymnasium -elif _has_gym: - set_gym_backend("gym").set() - import gym - - assert gym_backend() is gym - _has_meltingpot = importlib.util.find_spec("meltingpot") is not None _has_minigrid = importlib.util.find_spec("minigrid") is not None @@ -73,7 +60,21 @@ @pytest.fixture(scope="session", autouse=True) -def maybe_init_minigrid(): +def _setup_gym_backend(): + if _has_gymnasium: + set_gym_backend("gymnasium").set() + import gymnasium + + assert gym_backend() is gymnasium + elif _has_gym_regular: + set_gym_backend("gym").set() + import gym + + assert gym_backend() is gym + + +@pytest.fixture(scope="session", autouse=True) +def maybe_init_minigrid(_setup_gym_backend): if _has_minigrid and _has_gymnasium: import minigrid diff --git a/test/libs/test_datasets.py b/test/libs/test_datasets.py index a6a6116b595..9b7a687e317 100644 --- a/test/libs/test_datasets.py +++ b/test/libs/test_datasets.py @@ -60,12 +60,6 @@ _has_minari = importlib.util.find_spec("minari") is not None _has_gymnasium = importlib.util.find_spec("gymnasium") is not None -if importlib.util.find_spec("gym"): - import gym - -if _has_gymnasium: - import gymnasium - @pytest.mark.slow class TestGenDGRL: @@ -310,6 +304,7 @@ def test_d4rl_dummy(self, task): @pytest.mark.parametrize("from_env", [True, False]) def test_dataset_build(self, task, split_trajs, from_env): import d4rl # noqa: F401 + import gym t0 = time.time() data = D4RLExperienceReplay( @@ -670,6 +665,7 @@ def test_local_minari_dataset_loading(self, tmpdir): MINARI_DATASETS_PATH = os.environ.get("MINARI_DATASETS_PATH") os.environ["MINARI_DATASETS_PATH"] = str(tmpdir) try: + import gymnasium import minari from minari import DataCollector @@ -716,9 +712,9 @@ def test_local_minari_dataset_loading(self, tmpdir): torchrl_logger.info( f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms" ) - assert data.metadata["action_space"].is_in( - sample["action"] - ), "Invalid action sample" + assert data.metadata["action_space"].is_in(sample["action"]), ( + "Invalid action sample" + ) assert data.metadata["observation_space"].is_in( sample["observation"] ), "Invalid observation sample" @@ -1044,9 +1040,9 @@ def test_openx( sample = dataset.sample() assert sample.shape == (batch_size,) if slice_len is not None: - assert sample.get(("next", "done")).sum() == int( - batch_size // slice_len - ), sample.get(("next", "done")) + assert sample.get(("next", "done")).sum() == int(batch_size // slice_len), ( + sample.get(("next", "done")) + ) elif num_slices is not None: assert sample.get(("next", "done")).sum() == num_slices diff --git a/test/libs/test_gym.py b/test/libs/test_gym.py index a442ff40005..04bbed0f070 100644 --- a/test/libs/test_gym.py +++ b/test/libs/test_gym.py @@ -77,9 +77,6 @@ _has_gymnasium = importlib.util.find_spec("gymnasium") is not None _has_minigrid = importlib.util.find_spec("minigrid") is not None -if _has_gymnasium: - import gymnasium - try: from torch.utils._pytree import tree_flatten @@ -195,8 +192,7 @@ def _step(self, tensordict): batch_size=[], ) - def _set_seed(self, seed: int | None) -> None: - ... + def _set_seed(self, seed: int | None) -> None: ... @implement_for("gym", None, "0.18") def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape): @@ -281,13 +277,11 @@ def test_gym_spec_cast(self, categorical): def test_gym_new_spec_reg(self): Space = gym_backend("spaces").Space - class MySpaceParent(Space): - ... + class MySpaceParent(Space): ... s_parent = MySpaceParent() - class MySpaceChild(MySpaceParent): - ... + class MySpaceChild(MySpaceParent): ... # We intentionally register first the child then the parent @register_gym_spec_conversion(MySpaceChild) @@ -302,8 +296,7 @@ def convert_myspace_parent(spec, **kwargs): assert _gym_to_torchrl_spec_transform(s_parent).example_data == "parent" assert _gym_to_torchrl_spec_transform(s_child).example_data == "child" - class NoConversionSpace(Space): - ... + class NoConversionSpace(Space): ... s_no_conv = NoConversionSpace() with pytest.raises( @@ -1664,6 +1657,7 @@ def _test_resetting_strategies(self, heterogeneous, kwargs): def test_is_from_pixels_simple_env(self): """Test that _is_from_pixels correctly identifies non-pixel environments.""" + # Test with a simple environment that doesn't have pixels class SimpleEnv: def __init__(self): @@ -1681,6 +1675,7 @@ def __init__(self): def test_is_from_pixels_box_env(self): """Test that _is_from_pixels correctly identifies pixel Box environments.""" + # Test with a pixel-like environment class PixelEnv: def __init__(self): @@ -1700,6 +1695,7 @@ def __init__(self): def test_is_from_pixels_dict_env(self): """Test that _is_from_pixels correctly identifies Dict environments with pixels.""" + # Test with a Dict environment that has pixels class DictPixelEnv: def __init__(self): @@ -1718,12 +1714,13 @@ def __init__(self): # This should return True since it has a "pixels" key result = _is_from_pixels(dict_pixel_env) - assert ( - result is True - ), f"Expected True for Dict environment with pixels, got {result}" + assert result is True, ( + f"Expected True for Dict environment with pixels, got {result}" + ) def test_is_from_pixels_dict_env_no_pixels(self): """Test that _is_from_pixels correctly identifies Dict environments without pixels.""" + # Test with a Dict environment that doesn't have pixels class DictNoPixelEnv: def __init__(self): @@ -1742,9 +1739,9 @@ def __init__(self): # This should return False since it doesn't have a "pixels" key result = _is_from_pixels(dict_no_pixel_env) - assert ( - result is False - ), f"Expected False for Dict environment without pixels, got {result}" + assert result is False, ( + f"Expected False for Dict environment without pixels, got {result}" + ) def test_num_workers_returns_parallel_env(self): """Ensure explicit TorchRL `num_workers` returns a lazy ParallelEnv, while gym's @@ -1858,9 +1855,9 @@ def mock_isinstance(obj, cls): # This should return True since it's detected as a pixel wrapper result = _is_from_pixels(wrapped_env) - assert ( - result is True - ), f"Expected True for wrapped environment, got {result}" + assert result is True, ( + f"Expected True for wrapped environment, got {result}" + ) finally: # Restore original isinstance builtins.isinstance = original_isinstance @@ -1869,6 +1866,7 @@ def mock_isinstance(obj, cls): def test_gymnasium_num_envs(self, num_envs, request): if not _has_gymnasium: pytest.skip("gymnasium not found") + import gymnasium gym_version = version.parse(gymnasium.__version__) if version.parse("1.0.0") <= gym_version < version.parse("1.1.0"): @@ -1904,6 +1902,8 @@ class TestMiniGrid: ], ) def test_minigrid(self, id): + import gymnasium + env_base = gymnasium.make(id) env = GymWrapper(env_base) check_env_specs(env) diff --git a/test/modules/_modules_common.py b/test/modules/_modules_common.py index c6d6ea98b71..e4bf7362a6b 100644 --- a/test/modules/_modules_common.py +++ b/test/modules/_modules_common.py @@ -33,13 +33,6 @@ def _has_triton_backend() -> bool: _has_triton = _has_triton_backend() _triton_skip_reason = "requires triton (>= 2.2) and CUDA" -_has_functorch = False -try: - try: - from torch import vmap as vmap # noqa: F401 - except ImportError: - from functorch import vmap as vmap # noqa: F401 - - _has_functorch = True -except ImportError: - pass +_has_functorch = ( + hasattr(torch, "vmap") or importlib.util.find_spec("functorch") is not None +) diff --git a/test/modules/test_rnn.py b/test/modules/test_rnn.py index fb9aed7ad60..14ab673ed81 100644 --- a/test/modules/test_rnn.py +++ b/test/modules/test_rnn.py @@ -75,12 +75,6 @@ _has_hoptorch = importlib.util.find_spec("hoptorch") is not None -if _has_functorch: - try: - from torch import vmap - except ImportError: - from functorch import vmap - @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("bias", [True, False]) @@ -95,9 +89,9 @@ def test_python_lstm_cell(device, bias): lstm_cell1.named_parameters(), lstm_cell2.named_parameters() ): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + assert v1.shape == v2.shape, ( + f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + ) # Run loop input = torch.randn(2, 3, 10, device=device) @@ -131,9 +125,9 @@ def test_python_gru_cell(device, bias): ): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" assert (v1 == v2).all() - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + assert v1.shape == v2.shape, ( + f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + ) # Run loop input = torch.randn(2, 3, 10, device=device) @@ -179,9 +173,9 @@ def test_python_lstm(device, bias, dropout, batch_first, num_layers): # Make sure parameters match for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + assert v1.shape == v2.shape, ( + f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + ) if batch_first: input = torch.randn(B, T, 10, device=device) @@ -248,9 +242,9 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers): for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" torch.testing.assert_close(v1, v2) - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + assert v1.shape == v2.shape, ( + f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + ) if batch_first: input = torch.randn(B, T, 10, device=device) @@ -552,7 +546,6 @@ def test_lstm_parallel_env( def _test_lstm_parallel_env( self, python_based, parallel, heterogeneous, within, maybe_fork_ParallelEnv ): - torch.manual_seed(0) num_envs = 3 device = "cuda" if torch.cuda.device_count() else "cpu" @@ -654,6 +647,8 @@ def test_lstm_parallel_within( not _has_functorch, reason="vmap can only be used with functorch" ) def test_lstm_vmap_complex_model(self): + from torch import vmap + # Tests that all ops in GRU are compatible with VMAP (when build using # the PT backend). # This used to fail when splitting the input based on the is_init mask. @@ -2512,6 +2507,8 @@ def test_gru_parallel_within( not _has_functorch, reason="vmap can only be used with functorch" ) def test_gru_vmap_complex_model(self): + from torch import vmap + # Tests that all ops in GRU are compatible with VMAP (when build using # the PT backend). # This used to fail when splitting the input based on the is_init mask. @@ -3286,7 +3283,6 @@ def test_gru_module_scan_non_canonical_hidden_strides(self): def test_get_primers_from_module(): - # No primers in the model module = MLP(in_features=10, out_features=10, num_cells=[]) transform = get_primers_from_module(module) @@ -3862,9 +3858,9 @@ def loss_for(mod): atol = 5e-3 for k in grads_pad: - assert torch.isfinite( - grads_triton[k] - ).all(), f"precision={precision} produced non-finite grad for {k}" + assert torch.isfinite(grads_triton[k]).all(), ( + f"precision={precision} produced non-finite grad for {k}" + ) torch.testing.assert_close( grads_pad[k], grads_triton[k], atol=atol, rtol=atol ) @@ -3927,9 +3923,9 @@ def loss_for(mod): atol = 5e-3 for k in grads_pad: - assert torch.isfinite( - grads_triton[k] - ).all(), f"precision={precision} produced non-finite grad for {k}" + assert torch.isfinite(grads_triton[k]).all(), ( + f"precision={precision} produced non-finite grad for {k}" + ) torch.testing.assert_close( grads_pad[k], grads_triton[k], atol=atol, rtol=atol ) diff --git a/test/objectives/_objectives_common.py b/test/objectives/_objectives_common.py index abaa9ed7e21..570d303ea3e 100644 --- a/test/objectives/_objectives_common.py +++ b/test/objectives/_objectives_common.py @@ -24,16 +24,10 @@ from torchrl.data import Composite, Unbounded from torchrl.envs import EnvBase -_has_functorch = True -try: - import functorch as ft # noqa - - make_functional_with_buffers = ft.make_functional_with_buffers - FUNCTORCH_ERR = "" -except ImportError as err: - _has_functorch = False - FUNCTORCH_ERR = str(err) - make_functional_with_buffers = None +_has_functorch = ( + hasattr(torch, "vmap") or importlib.util.find_spec("functorch") is not None +) +FUNCTORCH_ERR = "" _has_transformers = bool(importlib.util.find_spec("transformers")) _has_botorch = bool(importlib.util.find_spec("botorch")) @@ -129,14 +123,11 @@ def make_composite_dist(cls): def _step( self, tensordict: TensorDictBase, - ) -> TensorDictBase: - ... + ) -> TensorDictBase: ... - def _reset(self, tensordic): - ... + def _reset(self, tensordic): ... - def _set_seed(self, seed: int | None) -> None: - ... + def _set_seed(self, seed: int | None) -> None: ... class LossModuleTestBase: @@ -149,9 +140,9 @@ def _composite_log_prob(self): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - assert hasattr( - cls, "test_reset_parameters_recursive" - ), "Please add a test_reset_parameters_recursive test for this class" + assert hasattr(cls, "test_reset_parameters_recursive"), ( + "Please add a test_reset_parameters_recursive test for this class" + ) def _flatten_in_keys(self, in_keys): return [ diff --git a/test/objectives/test_ppo.py b/test/objectives/test_ppo.py index d1666b8dc1f..64f38d45699 100644 --- a/test/objectives/test_ppo.py +++ b/test/objectives/test_ppo.py @@ -17,7 +17,6 @@ _has_transformers, FUNCTORCH_ERR, LossModuleTestBase, - make_functional_with_buffers, MARLEnv, ) @@ -1175,7 +1174,6 @@ def test_ppo_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) if beta is not None: - loss.beta = beta.clone() loss_val_td = loss(td) @@ -2163,6 +2161,8 @@ def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") + from functorch import make_functional_with_buffers + floss_fn, params, buffers = make_functional_with_buffers(loss_fn) if advantage is not None: diff --git a/test/test_helpers.py b/test/test_helpers.py index 61367727426..8ca58f7cead 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -6,6 +6,7 @@ import argparse import dataclasses +import importlib.util import pathlib import sys from time import sleep @@ -44,20 +45,18 @@ make_dqn_actor, ) -try: +_has_hydra = importlib.util.find_spec("hydra") is not None + +if _has_hydra: from hydra import compose, initialize from hydra.core.config_store import ConfigStore - _has_hydra = True - @pytest.fixture(autouse=True, scope="module") def clear_hydra(): from hydra.core.global_hydra import GlobalHydra GlobalHydra.instance().clear() -except ImportError: - _has_hydra = False TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) if TORCH_VERSION < version.parse("1.12.0"): @@ -73,7 +72,6 @@ def clear_hydra(): @pytest.fixture def dreamer_constructor_fixture(): - # we hack the env constructor sys.path.append( str(pathlib.Path(__file__).parent.parent / "sota-implementations" / "dreamer") @@ -220,7 +218,6 @@ def test_timeit(): @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") @pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) def test_transformed_env_constructor_with_state_dict(from_pixels): - config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( diff --git a/test/test_inference_server.py b/test/test_inference_server.py index 683250e9f6a..ec73d627513 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -5,6 +5,7 @@ from __future__ import annotations import concurrent.futures +import importlib.util import multiprocessing as mp import threading import time @@ -28,17 +29,8 @@ ) from torchrl.modules.inference_server._monarch import MonarchTransport -_has_ray = True -try: - import ray -except ImportError: - _has_ray = False - -_has_monarch = True -try: - import monarch # noqa: F401 -except ImportError: - _has_monarch = False +_has_ray = importlib.util.find_spec("ray") is not None +_has_monarch = importlib.util.find_spec("monarch") is not None # ============================================================================= diff --git a/test/test_trainer.py b/test/test_trainer.py index 4f9ed611815..9bf96ffb02e 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -18,13 +18,7 @@ import torch from torch import nn -try: - from tensorboard.backend.event_processing import event_accumulator - from torchrl.record.loggers.tensorboard import TensorboardLogger - - _has_tb = True -except ImportError: - _has_tb = False +_has_tb = importlib.util.find_spec("tensorboard") is not None from tensordict import TensorDict from torchrl.data import ( @@ -488,9 +482,7 @@ def make_storage(): if re_init: assert load_state_dict_has_been_called_td[0] if backend != "torch": - td1 = ( - storage._storage - ) # trainer.app_state["state"]["replay_buffer.replay_buffer._storage._storage"] + td1 = storage._storage # trainer.app_state["state"]["replay_buffer.replay_buffer._storage._storage"] td2 = trainer2._modules["replay_buffer"].replay_buffer._storage._storage if storage_type == "list": assert all((_td1 == _td2).all() for _td1, _td2 in zip(td1, td2)) @@ -901,6 +893,9 @@ def _get_args(self): return args def test_recorder(self, N=8): + from tensorboard.backend.event_processing import event_accumulator + from torchrl.record.loggers.tensorboard import TensorboardLogger + args = self._get_args() with tempfile.TemporaryDirectory() as folder: logger = TensorboardLogger(exp_name=folder) @@ -926,7 +921,7 @@ def test_recorder(self, N=8): for _ in range(N): recorder(None) - for (_, _, filenames) in walk(folder): + for _, _, filenames in walk(folder): filename = filenames[0] break @@ -954,6 +949,8 @@ def test_recorder(self, N=8): ], ) def test_recorder_load(self, backend, N=8): + from torchrl.record.loggers.tensorboard import TensorboardLogger + if not _has_ts and backend == "torchsnapshot": pytest.skip("torchsnapshot not found") @@ -963,7 +960,10 @@ def test_recorder_load(self, backend, N=8): LogValidationReward.state_dict, Recorder_state_dict = _fun_checker( LogValidationReward.state_dict, state_dict_has_been_called ) - (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker( + ( + LogValidationReward.load_state_dict, + Recorder_load_state_dict, + ) = _fun_checker( LogValidationReward.load_state_dict, load_state_dict_has_been_called ) @@ -1089,7 +1089,10 @@ def _make_countframe_and_trainer(tmpdirname): frame_skip = 3 batch = 10 - with tempfile.TemporaryDirectory() as tmpdirname, tempfile.TemporaryDirectory() as tmpdirname2: + with ( + tempfile.TemporaryDirectory() as tmpdirname, + tempfile.TemporaryDirectory() as tmpdirname2, + ): trainer, count_frames, file = _make_countframe_and_trainer(tmpdirname) td = TensorDict( { @@ -1298,9 +1301,9 @@ def capture_hook(optim_steps, average_losses): def test_subclass_exposes_auto_log_optim_steps(self, trainer_cls): """Every Trainer subclass must surface auto_log_optim_steps in its __init__.""" sig = inspect.signature(trainer_cls.__init__) - assert ( - "auto_log_optim_steps" in sig.parameters - ), f"{trainer_cls.__name__}.__init__ must accept auto_log_optim_steps" + assert "auto_log_optim_steps" in sig.parameters, ( + f"{trainer_cls.__name__}.__init__ must accept auto_log_optim_steps" + ) assert sig.parameters["auto_log_optim_steps"].default is True From 00cc79befe3fdc2ec06825a6a599c5fb2cda59ef Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 24 Jun 2026 18:13:08 -0700 Subject: [PATCH 2/5] [Test] Fix lazy optional imports --- test/modules/test_rnn.py | 17 +++++++++++++++-- test/test_helpers.py | 24 ++++++++++++++++++------ test/test_inference_server.py | 12 ++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/test/modules/test_rnn.py b/test/modules/test_rnn.py index 14ab673ed81..892660b8b83 100644 --- a/test/modules/test_rnn.py +++ b/test/modules/test_rnn.py @@ -74,6 +74,19 @@ ) _has_hoptorch = importlib.util.find_spec("hoptorch") is not None +_vmap = None + + +def _get_vmap(): + global _vmap + if _vmap is None: + if hasattr(torch, "vmap"): + _vmap = torch.vmap + else: + from functorch import vmap + + _vmap = vmap + return _vmap @pytest.mark.parametrize("device", get_default_devices()) @@ -647,7 +660,7 @@ def test_lstm_parallel_within( not _has_functorch, reason="vmap can only be used with functorch" ) def test_lstm_vmap_complex_model(self): - from torch import vmap + vmap = _get_vmap() # Tests that all ops in GRU are compatible with VMAP (when build using # the PT backend). @@ -2507,7 +2520,7 @@ def test_gru_parallel_within( not _has_functorch, reason="vmap can only be used with functorch" ) def test_gru_vmap_complex_model(self): - from torch import vmap + vmap = _get_vmap() # Tests that all ops in GRU are compatible with VMAP (when build using # the PT backend). diff --git a/test/test_helpers.py b/test/test_helpers.py index 8ca58f7cead..6b3efe71077 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -46,16 +46,24 @@ ) _has_hydra = importlib.util.find_spec("hydra") is not None +_hydra_deps = None -if _has_hydra: - from hydra import compose, initialize - from hydra.core.config_store import ConfigStore - @pytest.fixture(autouse=True, scope="module") - def clear_hydra(): +def _get_hydra_deps(): + global _hydra_deps + if _hydra_deps is None: + from hydra import compose, initialize + from hydra.core.config_store import ConfigStore from hydra.core.global_hydra import GlobalHydra - GlobalHydra.instance().clear() + _hydra_deps = compose, initialize, ConfigStore, GlobalHydra + return _hydra_deps + + +@pytest.fixture(scope="module") +def clear_hydra(): + *_, GlobalHydra = _get_hydra_deps() + GlobalHydra.instance().clear() TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) @@ -86,6 +94,7 @@ def dreamer_constructor_fixture(): @pytest.mark.skipif(not _has_gym, reason="No gym library found") @pytest.mark.skipif(not _has_tv, reason="No torchvision library found") @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") +@pytest.mark.usefixtures("clear_hydra") @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("noisy", [(), ("noisy=True",)]) @pytest.mark.parametrize("distributional", [(), ("distributional=True",)]) @@ -97,6 +106,7 @@ def dreamer_constructor_fixture(): def test_dqn_maker( device, noisy, distributional, from_pixels, categorical_action_encoding ): + compose, initialize, ConfigStore, _ = _get_hydra_deps() flags = list(noisy + distributional + from_pixels + categorical_action_encoding) + [ "env_name=CartPole-v1" ] @@ -216,8 +226,10 @@ def test_timeit(): @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") +@pytest.mark.usefixtures("clear_hydra") @pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) def test_transformed_env_constructor_with_state_dict(from_pixels): + compose, initialize, ConfigStore, _ = _get_hydra_deps() config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( diff --git a/test/test_inference_server.py b/test/test_inference_server.py index ec73d627513..4327874c704 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -31,6 +31,16 @@ _has_ray = importlib.util.find_spec("ray") is not None _has_monarch = importlib.util.find_spec("monarch") is not None +_ray = None + + +def _ray_lib(): + global _ray + if _ray is None: + import ray + + _ray = ray + return _ray # ============================================================================= @@ -413,6 +423,7 @@ def bad_model(td): class TestRayTransport: @classmethod def setup_class(cls): + ray = _ray_lib() if not ray.is_initialized(): ray.init(num_cpus=4, ignore_reinit_error=True) @@ -457,6 +468,7 @@ def client_fn(client_idx): def test_ray_remote_actor(self): """A Ray remote actor can use the client to get inference results.""" + ray = _ray_lib() transport = RayTransport() client = transport.client() policy = _make_policy() From ff416441e82a98688bca97f74519a978f17248cf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jun 2026 13:41:41 -0700 Subject: [PATCH 3/5] [Test] Apply lint formatting --- test/libs/conftest.py | 2 +- test/libs/test_datasets.py | 12 ++++----- test/libs/test_gym.py | 30 ++++++++++++---------- test/modules/test_rnn.py | 36 +++++++++++++-------------- test/objectives/_objectives_common.py | 15 ++++++----- test/test_trainer.py | 15 ++++++----- 6 files changed, 58 insertions(+), 52 deletions(-) diff --git a/test/libs/conftest.py b/test/libs/conftest.py index 78088fb7bd2..20267151b08 100644 --- a/test/libs/conftest.py +++ b/test/libs/conftest.py @@ -12,7 +12,7 @@ import torch from packaging import version -from torchrl.envs.libs.gym import _has_gym, gym_backend, set_gym_backend +from torchrl.envs.libs.gym import gym_backend, set_gym_backend pytestmark = [ pytest.mark.filterwarnings("error"), diff --git a/test/libs/test_datasets.py b/test/libs/test_datasets.py index 9b7a687e317..4f409814c18 100644 --- a/test/libs/test_datasets.py +++ b/test/libs/test_datasets.py @@ -712,9 +712,9 @@ def test_local_minari_dataset_loading(self, tmpdir): torchrl_logger.info( f"[Local Minari] Sampling time {1000 * (t1 - t0):4.4f} ms" ) - assert data.metadata["action_space"].is_in(sample["action"]), ( - "Invalid action sample" - ) + assert data.metadata["action_space"].is_in( + sample["action"] + ), "Invalid action sample" assert data.metadata["observation_space"].is_in( sample["observation"] ), "Invalid observation sample" @@ -1040,9 +1040,9 @@ def test_openx( sample = dataset.sample() assert sample.shape == (batch_size,) if slice_len is not None: - assert sample.get(("next", "done")).sum() == int(batch_size // slice_len), ( - sample.get(("next", "done")) - ) + assert sample.get(("next", "done")).sum() == int( + batch_size // slice_len + ), sample.get(("next", "done")) elif num_slices is not None: assert sample.get(("next", "done")).sum() == num_slices diff --git a/test/libs/test_gym.py b/test/libs/test_gym.py index 04bbed0f070..d7e70c709f8 100644 --- a/test/libs/test_gym.py +++ b/test/libs/test_gym.py @@ -192,7 +192,8 @@ def _step(self, tensordict): batch_size=[], ) - def _set_seed(self, seed: int | None) -> None: ... + def _set_seed(self, seed: int | None) -> None: + ... @implement_for("gym", None, "0.18") def _make_spec(self, batch_size, cat, cat_shape, multicat, multicat_shape): @@ -277,11 +278,13 @@ def test_gym_spec_cast(self, categorical): def test_gym_new_spec_reg(self): Space = gym_backend("spaces").Space - class MySpaceParent(Space): ... + class MySpaceParent(Space): + ... s_parent = MySpaceParent() - class MySpaceChild(MySpaceParent): ... + class MySpaceChild(MySpaceParent): + ... # We intentionally register first the child then the parent @register_gym_spec_conversion(MySpaceChild) @@ -296,7 +299,8 @@ def convert_myspace_parent(spec, **kwargs): assert _gym_to_torchrl_spec_transform(s_parent).example_data == "parent" assert _gym_to_torchrl_spec_transform(s_child).example_data == "child" - class NoConversionSpace(Space): ... + class NoConversionSpace(Space): + ... s_no_conv = NoConversionSpace() with pytest.raises( @@ -1714,9 +1718,9 @@ def __init__(self): # This should return True since it has a "pixels" key result = _is_from_pixels(dict_pixel_env) - assert result is True, ( - f"Expected True for Dict environment with pixels, got {result}" - ) + assert ( + result is True + ), f"Expected True for Dict environment with pixels, got {result}" def test_is_from_pixels_dict_env_no_pixels(self): """Test that _is_from_pixels correctly identifies Dict environments without pixels.""" @@ -1739,9 +1743,9 @@ def __init__(self): # This should return False since it doesn't have a "pixels" key result = _is_from_pixels(dict_no_pixel_env) - assert result is False, ( - f"Expected False for Dict environment without pixels, got {result}" - ) + assert ( + result is False + ), f"Expected False for Dict environment without pixels, got {result}" def test_num_workers_returns_parallel_env(self): """Ensure explicit TorchRL `num_workers` returns a lazy ParallelEnv, while gym's @@ -1855,9 +1859,9 @@ def mock_isinstance(obj, cls): # This should return True since it's detected as a pixel wrapper result = _is_from_pixels(wrapped_env) - assert result is True, ( - f"Expected True for wrapped environment, got {result}" - ) + assert ( + result is True + ), f"Expected True for wrapped environment, got {result}" finally: # Restore original isinstance builtins.isinstance = original_isinstance diff --git a/test/modules/test_rnn.py b/test/modules/test_rnn.py index 892660b8b83..b1eed4230cc 100644 --- a/test/modules/test_rnn.py +++ b/test/modules/test_rnn.py @@ -102,9 +102,9 @@ def test_python_lstm_cell(device, bias): lstm_cell1.named_parameters(), lstm_cell2.named_parameters() ): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert v1.shape == v2.shape, ( - f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - ) + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" # Run loop input = torch.randn(2, 3, 10, device=device) @@ -138,9 +138,9 @@ def test_python_gru_cell(device, bias): ): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" assert (v1 == v2).all() - assert v1.shape == v2.shape, ( - f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - ) + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" # Run loop input = torch.randn(2, 3, 10, device=device) @@ -186,9 +186,9 @@ def test_python_lstm(device, bias, dropout, batch_first, num_layers): # Make sure parameters match for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert v1.shape == v2.shape, ( - f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - ) + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" if batch_first: input = torch.randn(B, T, 10, device=device) @@ -255,9 +255,9 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers): for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()): assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" torch.testing.assert_close(v1, v2) - assert v1.shape == v2.shape, ( - f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - ) + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" if batch_first: input = torch.randn(B, T, 10, device=device) @@ -3871,9 +3871,9 @@ def loss_for(mod): atol = 5e-3 for k in grads_pad: - assert torch.isfinite(grads_triton[k]).all(), ( - f"precision={precision} produced non-finite grad for {k}" - ) + assert torch.isfinite( + grads_triton[k] + ).all(), f"precision={precision} produced non-finite grad for {k}" torch.testing.assert_close( grads_pad[k], grads_triton[k], atol=atol, rtol=atol ) @@ -3936,9 +3936,9 @@ def loss_for(mod): atol = 5e-3 for k in grads_pad: - assert torch.isfinite(grads_triton[k]).all(), ( - f"precision={precision} produced non-finite grad for {k}" - ) + assert torch.isfinite( + grads_triton[k] + ).all(), f"precision={precision} produced non-finite grad for {k}" torch.testing.assert_close( grads_pad[k], grads_triton[k], atol=atol, rtol=atol ) diff --git a/test/objectives/_objectives_common.py b/test/objectives/_objectives_common.py index 570d303ea3e..a6532b8444e 100644 --- a/test/objectives/_objectives_common.py +++ b/test/objectives/_objectives_common.py @@ -123,11 +123,14 @@ def make_composite_dist(cls): def _step( self, tensordict: TensorDictBase, - ) -> TensorDictBase: ... + ) -> TensorDictBase: + ... - def _reset(self, tensordic): ... + def _reset(self, tensordic): + ... - def _set_seed(self, seed: int | None) -> None: ... + def _set_seed(self, seed: int | None) -> None: + ... class LossModuleTestBase: @@ -140,9 +143,9 @@ def _composite_log_prob(self): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - assert hasattr(cls, "test_reset_parameters_recursive"), ( - "Please add a test_reset_parameters_recursive test for this class" - ) + assert hasattr( + cls, "test_reset_parameters_recursive" + ), "Please add a test_reset_parameters_recursive test for this class" def _flatten_in_keys(self, in_keys): return [ diff --git a/test/test_trainer.py b/test/test_trainer.py index 9bf96ffb02e..e520c9ee773 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -482,7 +482,9 @@ def make_storage(): if re_init: assert load_state_dict_has_been_called_td[0] if backend != "torch": - td1 = storage._storage # trainer.app_state["state"]["replay_buffer.replay_buffer._storage._storage"] + td1 = ( + storage._storage + ) # trainer.app_state["state"]["replay_buffer.replay_buffer._storage._storage"] td2 = trainer2._modules["replay_buffer"].replay_buffer._storage._storage if storage_type == "list": assert all((_td1 == _td2).all() for _td1, _td2 in zip(td1, td2)) @@ -960,10 +962,7 @@ def test_recorder_load(self, backend, N=8): LogValidationReward.state_dict, Recorder_state_dict = _fun_checker( LogValidationReward.state_dict, state_dict_has_been_called ) - ( - LogValidationReward.load_state_dict, - Recorder_load_state_dict, - ) = _fun_checker( + (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker( LogValidationReward.load_state_dict, load_state_dict_has_been_called ) @@ -1301,9 +1300,9 @@ def capture_hook(optim_steps, average_losses): def test_subclass_exposes_auto_log_optim_steps(self, trainer_cls): """Every Trainer subclass must surface auto_log_optim_steps in its __init__.""" sig = inspect.signature(trainer_cls.__init__) - assert "auto_log_optim_steps" in sig.parameters, ( - f"{trainer_cls.__name__}.__init__ must accept auto_log_optim_steps" - ) + assert ( + "auto_log_optim_steps" in sig.parameters + ), f"{trainer_cls.__name__}.__init__ must accept auto_log_optim_steps" assert sig.parameters["auto_log_optim_steps"].default is True From 525ff299be18130a01edbf0d1b1a7a3719256a7d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jun 2026 16:36:57 -0700 Subject: [PATCH 4/5] [CI] Keep benchmark torch nightly pinned --- .github/workflows/benchmarks.yml | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 720af90a37a..e481392f89a 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -91,15 +91,22 @@ jobs: # The --extra-index-url onto PyPI is required: torch nightly pulls in # transitive deps (e.g. spmd-types) that are only shipped as sdists on the # torch channel, and building those sdists needs setuptools/wheel which the - # torch index does not host. torch/torchvision still resolve from nightly - # (their dev versions outrank any PyPI stable), and assert_torch_version.sh - # below fails the job loudly if that ever stops holding. + # torch index does not host. Capture the installed nightly + # torch/torchvision versions and constrain later dependency installs so + # PyPI stable releases cannot upgrade them before the version assertion. python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U - python3.10 -m pip install ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate - python -m pip install "pybind11[global]" - python3.10 -m pip install cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" + python3.10 - <<'PY' > /tmp/torch-constraints.txt + import torch + import torchvision + + print(f"torch=={torch.__version__}") + print(f"torchvision=={torchvision.__version__}") + PY + python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate + python -m pip install -c /tmp/torch-constraints.txt "pybind11[global]" + python3.10 -m pip install -c /tmp/torch-constraints.txt cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" python3.10 -m pip install --no-deps git+https://github.com/pytorch/tensordict - python3.10 -m pip install safetensors tqdm pandas numpy matplotlib ray + python3.10 -m pip install -c /tmp/torch-constraints.txt safetensors tqdm pandas numpy matplotlib ray python3.10 -m pip install -e . --no-build-isolation --no-deps bash .github/unittest/helpers/assert_torch_version.sh nightly From ad6939db08fa3d3082d51d85f4e1937a37c4c604 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 25 Jun 2026 16:55:08 -0700 Subject: [PATCH 5/5] [CI] Avoid benchmark torch resolver downgrade --- .github/workflows/benchmarks.yml | 18 ++++++++++-------- .github/workflows/benchmarks_pr.yml | 25 +++++++++++++++++-------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index e481392f89a..a0f886a7a1b 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -91,16 +91,18 @@ jobs: # The --extra-index-url onto PyPI is required: torch nightly pulls in # transitive deps (e.g. spmd-types) that are only shipped as sdists on the # torch channel, and building those sdists needs setuptools/wheel which the - # torch index does not host. Capture the installed nightly - # torch/torchvision versions and constrain later dependency installs so - # PyPI stable releases cannot upgrade them before the version assertion. - python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U + # torch index does not host. Install torch separately so torchvision's + # exact torch dependency cannot make pip backtrack into PyPI stable + # torch. Then install nightly torchvision without dependencies and + # constrain later dependency installs so PyPI stable releases cannot + # upgrade torch/torchvision before the version assertion. + python3.10 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U + python3.10 -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --no-deps -U python3.10 - <<'PY' > /tmp/torch-constraints.txt - import torch - import torchvision + from importlib.metadata import version - print(f"torch=={torch.__version__}") - print(f"torchvision=={torchvision.__version__}") + print(f"torch=={version('torch')}") + print(f"torchvision=={version('torchvision')}") PY python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate python -m pip install -c /tmp/torch-constraints.txt "pybind11[global]" diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 15fdb316d14..b65439f7eb5 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -104,15 +104,24 @@ jobs: # The --extra-index-url onto PyPI is required: torch nightly pulls in # transitive deps (e.g. spmd-types) that are only shipped as sdists on the # torch channel, and building those sdists needs setuptools/wheel which the - # torch index does not host. torch/torchvision still resolve from nightly - # (their dev versions outrank any PyPI stable), and assert_torch_version.sh - # below fails the job loudly if that ever stops holding. - python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U - python3.10 -m pip install ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate ray - python3.10 -m pip install "pybind11[global]" - python3.10 -m pip install cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" + # torch index does not host. Install torch separately so torchvision's + # exact torch dependency cannot make pip backtrack into PyPI stable + # torch. Then install nightly torchvision without dependencies and + # constrain later dependency installs so PyPI stable releases cannot + # upgrade torch/torchvision before the version assertion. + python3.10 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U + python3.10 -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --no-deps -U + python3.10 - <<'PY' > /tmp/torch-constraints.txt + from importlib.metadata import version + + print(f"torch=={version('torch')}") + print(f"torchvision=={version('torchvision')}") + PY + python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate ray + python3.10 -m pip install -c /tmp/torch-constraints.txt "pybind11[global]" + python3.10 -m pip install -c /tmp/torch-constraints.txt cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" python3.10 -m pip install --no-deps git+https://github.com/pytorch/tensordict - python3.10 -m pip install safetensors tqdm pandas numpy matplotlib + python3.10 -m pip install -c /tmp/torch-constraints.txt safetensors tqdm pandas numpy matplotlib bash .github/unittest/helpers/assert_torch_version.sh nightly bash .github/unittest/helpers/assert_torch_tensordict_versions.sh nightly