diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 720af90a37a..a0f886a7a1b 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -91,15 +91,24 @@ jobs: # The --extra-index-url onto PyPI is required: torch nightly pulls in # transitive deps (e.g. spmd-types) that are only shipped as sdists on the # torch channel, and building those sdists needs setuptools/wheel which the - # torch index does not host. torch/torchvision still resolve from nightly - # (their dev versions outrank any PyPI stable), and assert_torch_version.sh - # below fails the job loudly if that ever stops holding. - python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U - python3.10 -m pip install ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate - python -m pip install "pybind11[global]" - python3.10 -m pip install cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" + # torch index does not host. Install torch separately so torchvision's + # exact torch dependency cannot make pip backtrack into PyPI stable + # torch. Then install nightly torchvision without dependencies and + # constrain later dependency installs so PyPI stable releases cannot + # upgrade torch/torchvision before the version assertion. + python3.10 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U + python3.10 -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --no-deps -U + python3.10 - <<'PY' > /tmp/torch-constraints.txt + from importlib.metadata import version + + print(f"torch=={version('torch')}") + print(f"torchvision=={version('torchvision')}") + PY + python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate + python -m pip install -c /tmp/torch-constraints.txt "pybind11[global]" + python3.10 -m pip install -c /tmp/torch-constraints.txt cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" python3.10 -m pip install --no-deps git+https://github.com/pytorch/tensordict - python3.10 -m pip install safetensors tqdm pandas numpy matplotlib ray + python3.10 -m pip install -c /tmp/torch-constraints.txt safetensors tqdm pandas numpy matplotlib ray python3.10 -m pip install -e . --no-build-isolation --no-deps bash .github/unittest/helpers/assert_torch_version.sh nightly diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 15fdb316d14..b65439f7eb5 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -104,15 +104,24 @@ jobs: # The --extra-index-url onto PyPI is required: torch nightly pulls in # transitive deps (e.g. spmd-types) that are only shipped as sdists on the # torch channel, and building those sdists needs setuptools/wheel which the - # torch index does not host. torch/torchvision still resolve from nightly - # (their dev versions outrank any PyPI stable), and assert_torch_version.sh - # below fails the job loudly if that ever stops holding. - python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U - python3.10 -m pip install ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate ray - python3.10 -m pip install "pybind11[global]" - python3.10 -m pip install cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" + # torch index does not host. Install torch separately so torchvision's + # exact torch dependency cannot make pip backtrack into PyPI stable + # torch. Then install nightly torchvision without dependencies and + # constrain later dependency installs so PyPI stable releases cannot + # upgrade torch/torchvision before the version assertion. + python3.10 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U + python3.10 -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --no-deps -U + python3.10 - <<'PY' > /tmp/torch-constraints.txt + from importlib.metadata import version + + print(f"torch=={version('torch')}") + print(f"torchvision=={version('torchvision')}") + PY + python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate ray + python3.10 -m pip install -c /tmp/torch-constraints.txt "pybind11[global]" + python3.10 -m pip install -c /tmp/torch-constraints.txt cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0" python3.10 -m pip install --no-deps git+https://github.com/pytorch/tensordict - python3.10 -m pip install safetensors tqdm pandas numpy matplotlib + python3.10 -m pip install -c /tmp/torch-constraints.txt safetensors tqdm pandas numpy matplotlib bash .github/unittest/helpers/assert_torch_version.sh nightly bash .github/unittest/helpers/assert_torch_tensordict_versions.sh nightly diff --git a/test/libs/conftest.py b/test/libs/conftest.py index 79baad212d6..20267151b08 100644 --- a/test/libs/conftest.py +++ b/test/libs/conftest.py @@ -12,7 +12,7 @@ import torch from packaging import version -from torchrl.envs.libs.gym import _has_gym, gym_backend, set_gym_backend +from torchrl.envs.libs.gym import gym_backend, set_gym_backend pytestmark = [ pytest.mark.filterwarnings("error"), @@ -51,20 +51,7 @@ _has_gymnasium = importlib.util.find_spec("gymnasium") is not None -_has_isaaclab = importlib.util.find_spec("isaaclab") is not None - _has_gym_regular = importlib.util.find_spec("gym") is not None -if _has_gymnasium: - set_gym_backend("gymnasium").set() - import gymnasium - - assert gym_backend() is gymnasium -elif _has_gym: - set_gym_backend("gym").set() - import gym - - assert gym_backend() is gym - _has_meltingpot = importlib.util.find_spec("meltingpot") is not None _has_minigrid = importlib.util.find_spec("minigrid") is not None @@ -73,7 +60,21 @@ @pytest.fixture(scope="session", autouse=True) -def maybe_init_minigrid(): +def _setup_gym_backend(): + if _has_gymnasium: + set_gym_backend("gymnasium").set() + import gymnasium + + assert gym_backend() is gymnasium + elif _has_gym_regular: + set_gym_backend("gym").set() + import gym + + assert gym_backend() is gym + + +@pytest.fixture(scope="session", autouse=True) +def maybe_init_minigrid(_setup_gym_backend): if _has_minigrid and _has_gymnasium: import minigrid diff --git a/test/libs/test_datasets.py b/test/libs/test_datasets.py index a6a6116b595..4f409814c18 100644 --- a/test/libs/test_datasets.py +++ b/test/libs/test_datasets.py @@ -60,12 +60,6 @@ _has_minari = importlib.util.find_spec("minari") is not None _has_gymnasium = importlib.util.find_spec("gymnasium") is not None -if importlib.util.find_spec("gym"): - import gym - -if _has_gymnasium: - import gymnasium - @pytest.mark.slow class TestGenDGRL: @@ -310,6 +304,7 @@ def test_d4rl_dummy(self, task): @pytest.mark.parametrize("from_env", [True, False]) def test_dataset_build(self, task, split_trajs, from_env): import d4rl # noqa: F401 + import gym t0 = time.time() data = D4RLExperienceReplay( @@ -670,6 +665,7 @@ def test_local_minari_dataset_loading(self, tmpdir): MINARI_DATASETS_PATH = os.environ.get("MINARI_DATASETS_PATH") os.environ["MINARI_DATASETS_PATH"] = str(tmpdir) try: + import gymnasium import minari from minari import DataCollector diff --git a/test/libs/test_gym.py b/test/libs/test_gym.py index a442ff40005..d7e70c709f8 100644 --- a/test/libs/test_gym.py +++ b/test/libs/test_gym.py @@ -77,9 +77,6 @@ _has_gymnasium = importlib.util.find_spec("gymnasium") is not None _has_minigrid = importlib.util.find_spec("minigrid") is not None -if _has_gymnasium: - import gymnasium - try: from torch.utils._pytree import tree_flatten @@ -1664,6 +1661,7 @@ def _test_resetting_strategies(self, heterogeneous, kwargs): def test_is_from_pixels_simple_env(self): """Test that _is_from_pixels correctly identifies non-pixel environments.""" + # Test with a simple environment that doesn't have pixels class SimpleEnv: def __init__(self): @@ -1681,6 +1679,7 @@ def __init__(self): def test_is_from_pixels_box_env(self): """Test that _is_from_pixels correctly identifies pixel Box environments.""" + # Test with a pixel-like environment class PixelEnv: def __init__(self): @@ -1700,6 +1699,7 @@ def __init__(self): def test_is_from_pixels_dict_env(self): """Test that _is_from_pixels correctly identifies Dict environments with pixels.""" + # Test with a Dict environment that has pixels class DictPixelEnv: def __init__(self): @@ -1724,6 +1724,7 @@ def __init__(self): def test_is_from_pixels_dict_env_no_pixels(self): """Test that _is_from_pixels correctly identifies Dict environments without pixels.""" + # Test with a Dict environment that doesn't have pixels class DictNoPixelEnv: def __init__(self): @@ -1869,6 +1870,7 @@ def mock_isinstance(obj, cls): def test_gymnasium_num_envs(self, num_envs, request): if not _has_gymnasium: pytest.skip("gymnasium not found") + import gymnasium gym_version = version.parse(gymnasium.__version__) if version.parse("1.0.0") <= gym_version < version.parse("1.1.0"): @@ -1904,6 +1906,8 @@ class TestMiniGrid: ], ) def test_minigrid(self, id): + import gymnasium + env_base = gymnasium.make(id) env = GymWrapper(env_base) check_env_specs(env) diff --git a/test/modules/_modules_common.py b/test/modules/_modules_common.py index c6d6ea98b71..e4bf7362a6b 100644 --- a/test/modules/_modules_common.py +++ b/test/modules/_modules_common.py @@ -33,13 +33,6 @@ def _has_triton_backend() -> bool: _has_triton = _has_triton_backend() _triton_skip_reason = "requires triton (>= 2.2) and CUDA" -_has_functorch = False -try: - try: - from torch import vmap as vmap # noqa: F401 - except ImportError: - from functorch import vmap as vmap # noqa: F401 - - _has_functorch = True -except ImportError: - pass +_has_functorch = ( + hasattr(torch, "vmap") or importlib.util.find_spec("functorch") is not None +) diff --git a/test/modules/test_rnn.py b/test/modules/test_rnn.py index fb9aed7ad60..b1eed4230cc 100644 --- a/test/modules/test_rnn.py +++ b/test/modules/test_rnn.py @@ -74,12 +74,19 @@ ) _has_hoptorch = importlib.util.find_spec("hoptorch") is not None +_vmap = None -if _has_functorch: - try: - from torch import vmap - except ImportError: - from functorch import vmap + +def _get_vmap(): + global _vmap + if _vmap is None: + if hasattr(torch, "vmap"): + _vmap = torch.vmap + else: + from functorch import vmap + + _vmap = vmap + return _vmap @pytest.mark.parametrize("device", get_default_devices()) @@ -552,7 +559,6 @@ def test_lstm_parallel_env( def _test_lstm_parallel_env( self, python_based, parallel, heterogeneous, within, maybe_fork_ParallelEnv ): - torch.manual_seed(0) num_envs = 3 device = "cuda" if torch.cuda.device_count() else "cpu" @@ -654,6 +660,8 @@ def test_lstm_parallel_within( not _has_functorch, reason="vmap can only be used with functorch" ) def test_lstm_vmap_complex_model(self): + vmap = _get_vmap() + # Tests that all ops in GRU are compatible with VMAP (when build using # the PT backend). # This used to fail when splitting the input based on the is_init mask. @@ -2512,6 +2520,8 @@ def test_gru_parallel_within( not _has_functorch, reason="vmap can only be used with functorch" ) def test_gru_vmap_complex_model(self): + vmap = _get_vmap() + # Tests that all ops in GRU are compatible with VMAP (when build using # the PT backend). # This used to fail when splitting the input based on the is_init mask. @@ -3286,7 +3296,6 @@ def test_gru_module_scan_non_canonical_hidden_strides(self): def test_get_primers_from_module(): - # No primers in the model module = MLP(in_features=10, out_features=10, num_cells=[]) transform = get_primers_from_module(module) diff --git a/test/objectives/_objectives_common.py b/test/objectives/_objectives_common.py index abaa9ed7e21..a6532b8444e 100644 --- a/test/objectives/_objectives_common.py +++ b/test/objectives/_objectives_common.py @@ -24,16 +24,10 @@ from torchrl.data import Composite, Unbounded from torchrl.envs import EnvBase -_has_functorch = True -try: - import functorch as ft # noqa - - make_functional_with_buffers = ft.make_functional_with_buffers - FUNCTORCH_ERR = "" -except ImportError as err: - _has_functorch = False - FUNCTORCH_ERR = str(err) - make_functional_with_buffers = None +_has_functorch = ( + hasattr(torch, "vmap") or importlib.util.find_spec("functorch") is not None +) +FUNCTORCH_ERR = "" _has_transformers = bool(importlib.util.find_spec("transformers")) _has_botorch = bool(importlib.util.find_spec("botorch")) diff --git a/test/objectives/test_ppo.py b/test/objectives/test_ppo.py index d1666b8dc1f..64f38d45699 100644 --- a/test/objectives/test_ppo.py +++ b/test/objectives/test_ppo.py @@ -17,7 +17,6 @@ _has_transformers, FUNCTORCH_ERR, LossModuleTestBase, - make_functional_with_buffers, MARLEnv, ) @@ -1175,7 +1174,6 @@ def test_ppo_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) if beta is not None: - loss.beta = beta.clone() loss_val_td = loss(td) @@ -2163,6 +2161,8 @@ def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") + from functorch import make_functional_with_buffers + floss_fn, params, buffers = make_functional_with_buffers(loss_fn) if advantage is not None: diff --git a/test/test_helpers.py b/test/test_helpers.py index 61367727426..6b3efe71077 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -6,6 +6,7 @@ import argparse import dataclasses +import importlib.util import pathlib import sys from time import sleep @@ -44,20 +45,26 @@ make_dqn_actor, ) -try: - from hydra import compose, initialize - from hydra.core.config_store import ConfigStore +_has_hydra = importlib.util.find_spec("hydra") is not None +_hydra_deps = None - _has_hydra = True - @pytest.fixture(autouse=True, scope="module") - def clear_hydra(): +def _get_hydra_deps(): + global _hydra_deps + if _hydra_deps is None: + from hydra import compose, initialize + from hydra.core.config_store import ConfigStore from hydra.core.global_hydra import GlobalHydra - GlobalHydra.instance().clear() + _hydra_deps = compose, initialize, ConfigStore, GlobalHydra + return _hydra_deps + + +@pytest.fixture(scope="module") +def clear_hydra(): + *_, GlobalHydra = _get_hydra_deps() + GlobalHydra.instance().clear() -except ImportError: - _has_hydra = False TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) if TORCH_VERSION < version.parse("1.12.0"): @@ -73,7 +80,6 @@ def clear_hydra(): @pytest.fixture def dreamer_constructor_fixture(): - # we hack the env constructor sys.path.append( str(pathlib.Path(__file__).parent.parent / "sota-implementations" / "dreamer") @@ -88,6 +94,7 @@ def dreamer_constructor_fixture(): @pytest.mark.skipif(not _has_gym, reason="No gym library found") @pytest.mark.skipif(not _has_tv, reason="No torchvision library found") @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") +@pytest.mark.usefixtures("clear_hydra") @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("noisy", [(), ("noisy=True",)]) @pytest.mark.parametrize("distributional", [(), ("distributional=True",)]) @@ -99,6 +106,7 @@ def dreamer_constructor_fixture(): def test_dqn_maker( device, noisy, distributional, from_pixels, categorical_action_encoding ): + compose, initialize, ConfigStore, _ = _get_hydra_deps() flags = list(noisy + distributional + from_pixels + categorical_action_encoding) + [ "env_name=CartPole-v1" ] @@ -218,9 +226,10 @@ def test_timeit(): @pytest.mark.skipif(not _has_hydra, reason="No hydra library found") +@pytest.mark.usefixtures("clear_hydra") @pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")]) def test_transformed_env_constructor_with_state_dict(from_pixels): - + compose, initialize, ConfigStore, _ = _get_hydra_deps() config_fields = [ (config_field.name, config_field.type, config_field) for config_cls in ( diff --git a/test/test_inference_server.py b/test/test_inference_server.py index 683250e9f6a..4327874c704 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -5,6 +5,7 @@ from __future__ import annotations import concurrent.futures +import importlib.util import multiprocessing as mp import threading import time @@ -28,17 +29,18 @@ ) from torchrl.modules.inference_server._monarch import MonarchTransport -_has_ray = True -try: - import ray -except ImportError: - _has_ray = False +_has_ray = importlib.util.find_spec("ray") is not None +_has_monarch = importlib.util.find_spec("monarch") is not None +_ray = None -_has_monarch = True -try: - import monarch # noqa: F401 -except ImportError: - _has_monarch = False + +def _ray_lib(): + global _ray + if _ray is None: + import ray + + _ray = ray + return _ray # ============================================================================= @@ -421,6 +423,7 @@ def bad_model(td): class TestRayTransport: @classmethod def setup_class(cls): + ray = _ray_lib() if not ray.is_initialized(): ray.init(num_cpus=4, ignore_reinit_error=True) @@ -465,6 +468,7 @@ def client_fn(client_idx): def test_ray_remote_actor(self): """A Ray remote actor can use the client to get inference results.""" + ray = _ray_lib() transport = RayTransport() client = transport.client() policy = _make_policy() diff --git a/test/test_trainer.py b/test/test_trainer.py index 4f9ed611815..e520c9ee773 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -18,13 +18,7 @@ import torch from torch import nn -try: - from tensorboard.backend.event_processing import event_accumulator - from torchrl.record.loggers.tensorboard import TensorboardLogger - - _has_tb = True -except ImportError: - _has_tb = False +_has_tb = importlib.util.find_spec("tensorboard") is not None from tensordict import TensorDict from torchrl.data import ( @@ -901,6 +895,9 @@ def _get_args(self): return args def test_recorder(self, N=8): + from tensorboard.backend.event_processing import event_accumulator + from torchrl.record.loggers.tensorboard import TensorboardLogger + args = self._get_args() with tempfile.TemporaryDirectory() as folder: logger = TensorboardLogger(exp_name=folder) @@ -926,7 +923,7 @@ def test_recorder(self, N=8): for _ in range(N): recorder(None) - for (_, _, filenames) in walk(folder): + for _, _, filenames in walk(folder): filename = filenames[0] break @@ -954,6 +951,8 @@ def test_recorder(self, N=8): ], ) def test_recorder_load(self, backend, N=8): + from torchrl.record.loggers.tensorboard import TensorboardLogger + if not _has_ts and backend == "torchsnapshot": pytest.skip("torchsnapshot not found") @@ -1089,7 +1088,10 @@ def _make_countframe_and_trainer(tmpdirname): frame_skip = 3 batch = 10 - with tempfile.TemporaryDirectory() as tmpdirname, tempfile.TemporaryDirectory() as tmpdirname2: + with ( + tempfile.TemporaryDirectory() as tmpdirname, + tempfile.TemporaryDirectory() as tmpdirname2, + ): trainer, count_frames, file = _make_countframe_and_trainer(tmpdirname) td = TensorDict( {