Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,24 @@ jobs:
# The --extra-index-url onto PyPI is required: torch nightly pulls in
# transitive deps (e.g. spmd-types) that are only shipped as sdists on the
# torch channel, and building those sdists needs setuptools/wheel which the
# torch index does not host. torch/torchvision still resolve from nightly
# (their dev versions outrank any PyPI stable), and assert_torch_version.sh
# below fails the job loudly if that ever stops holding.
python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U
python3.10 -m pip install ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate
python -m pip install "pybind11[global]"
python3.10 -m pip install cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0"
# torch index does not host. Install torch separately so torchvision's
# exact torch dependency cannot make pip backtrack into PyPI stable
# torch. Then install nightly torchvision without dependencies and
# constrain later dependency installs so PyPI stable releases cannot
# upgrade torch/torchvision before the version assertion.
python3.10 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U
python3.10 -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --no-deps -U
python3.10 - <<'PY' > /tmp/torch-constraints.txt
from importlib.metadata import version

print(f"torch=={version('torch')}")
print(f"torchvision=={version('torchvision')}")
PY
python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate
python -m pip install -c /tmp/torch-constraints.txt "pybind11[global]"
python3.10 -m pip install -c /tmp/torch-constraints.txt cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0"
python3.10 -m pip install --no-deps git+https://github.com/pytorch/tensordict
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib ray
python3.10 -m pip install -c /tmp/torch-constraints.txt safetensors tqdm pandas numpy matplotlib ray
python3.10 -m pip install -e . --no-build-isolation --no-deps

bash .github/unittest/helpers/assert_torch_version.sh nightly
Expand Down
25 changes: 17 additions & 8 deletions .github/workflows/benchmarks_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,24 @@ jobs:
# The --extra-index-url onto PyPI is required: torch nightly pulls in
# transitive deps (e.g. spmd-types) that are only shipped as sdists on the
# torch channel, and building those sdists needs setuptools/wheel which the
# torch index does not host. torch/torchvision still resolve from nightly
# (their dev versions outrank any PyPI stable), and assert_torch_version.sh
# below fails the job loudly if that ever stops holding.
python3.10 -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U
python3.10 -m pip install ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate ray
python3.10 -m pip install "pybind11[global]"
python3.10 -m pip install cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0"
# torch index does not host. Install torch separately so torchvision's
# exact torch dependency cannot make pip backtrack into PyPI stable
# torch. Then install nightly torchvision without dependencies and
# constrain later dependency installs so PyPI stable releases cannot
# upgrade torch/torchvision before the version assertion.
python3.10 -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --extra-index-url https://pypi.org/simple -U
python3.10 -m pip install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu126 --no-deps -U
python3.10 - <<'PY' > /tmp/torch-constraints.txt
from importlib.metadata import version

print(f"torch=={version('torch')}")
print(f"torchvision=={version('torchvision')}")
PY
python3.10 -m pip install -c /tmp/torch-constraints.txt ninja pytest pytest-benchmark pytest-timeout "hoptorch>=0.1.4" "mujoco>=3.8.1,<3.9.0" "dm_control>=1.0.41" "gym[accept-rom-license,atari]" transformers accelerate ray
python3.10 -m pip install -c /tmp/torch-constraints.txt "pybind11[global]"
python3.10 -m pip install -c /tmp/torch-constraints.txt cloudpickle packaging importlib_metadata numpy orjson "pyvers>=0.2.0,<0.3.0"
python3.10 -m pip install --no-deps git+https://github.com/pytorch/tensordict
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
python3.10 -m pip install -c /tmp/torch-constraints.txt safetensors tqdm pandas numpy matplotlib

bash .github/unittest/helpers/assert_torch_version.sh nightly
bash .github/unittest/helpers/assert_torch_tensordict_versions.sh nightly
Expand Down
31 changes: 16 additions & 15 deletions test/libs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from packaging import version

from torchrl.envs.libs.gym import _has_gym, gym_backend, set_gym_backend
from torchrl.envs.libs.gym import gym_backend, set_gym_backend

pytestmark = [
pytest.mark.filterwarnings("error"),
Expand Down Expand Up @@ -51,20 +51,7 @@

_has_gymnasium = importlib.util.find_spec("gymnasium") is not None

_has_isaaclab = importlib.util.find_spec("isaaclab") is not None

_has_gym_regular = importlib.util.find_spec("gym") is not None
if _has_gymnasium:
set_gym_backend("gymnasium").set()
import gymnasium

assert gym_backend() is gymnasium
elif _has_gym:
set_gym_backend("gym").set()
import gym

assert gym_backend() is gym

_has_meltingpot = importlib.util.find_spec("meltingpot") is not None

_has_minigrid = importlib.util.find_spec("minigrid") is not None
Expand All @@ -73,7 +60,21 @@


@pytest.fixture(scope="session", autouse=True)
def maybe_init_minigrid():
def _setup_gym_backend():
if _has_gymnasium:
set_gym_backend("gymnasium").set()
import gymnasium

assert gym_backend() is gymnasium
elif _has_gym_regular:
set_gym_backend("gym").set()
import gym

assert gym_backend() is gym


@pytest.fixture(scope="session", autouse=True)
def maybe_init_minigrid(_setup_gym_backend):
if _has_minigrid and _has_gymnasium:
import minigrid

Expand Down
8 changes: 2 additions & 6 deletions test/libs/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@
_has_minari = importlib.util.find_spec("minari") is not None
_has_gymnasium = importlib.util.find_spec("gymnasium") is not None

if importlib.util.find_spec("gym"):
import gym

if _has_gymnasium:
import gymnasium


@pytest.mark.slow
class TestGenDGRL:
Expand Down Expand Up @@ -310,6 +304,7 @@ def test_d4rl_dummy(self, task):
@pytest.mark.parametrize("from_env", [True, False])
def test_dataset_build(self, task, split_trajs, from_env):
import d4rl # noqa: F401
import gym

t0 = time.time()
data = D4RLExperienceReplay(
Expand Down Expand Up @@ -670,6 +665,7 @@ def test_local_minari_dataset_loading(self, tmpdir):
MINARI_DATASETS_PATH = os.environ.get("MINARI_DATASETS_PATH")
os.environ["MINARI_DATASETS_PATH"] = str(tmpdir)
try:
import gymnasium
import minari
from minari import DataCollector

Expand Down
10 changes: 7 additions & 3 deletions test/libs/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@
_has_gymnasium = importlib.util.find_spec("gymnasium") is not None
_has_minigrid = importlib.util.find_spec("minigrid") is not None

if _has_gymnasium:
import gymnasium

try:
from torch.utils._pytree import tree_flatten

Expand Down Expand Up @@ -1664,6 +1661,7 @@ def _test_resetting_strategies(self, heterogeneous, kwargs):

def test_is_from_pixels_simple_env(self):
"""Test that _is_from_pixels correctly identifies non-pixel environments."""

# Test with a simple environment that doesn't have pixels
class SimpleEnv:
def __init__(self):
Expand All @@ -1681,6 +1679,7 @@ def __init__(self):

def test_is_from_pixels_box_env(self):
"""Test that _is_from_pixels correctly identifies pixel Box environments."""

# Test with a pixel-like environment
class PixelEnv:
def __init__(self):
Expand All @@ -1700,6 +1699,7 @@ def __init__(self):

def test_is_from_pixels_dict_env(self):
"""Test that _is_from_pixels correctly identifies Dict environments with pixels."""

# Test with a Dict environment that has pixels
class DictPixelEnv:
def __init__(self):
Expand All @@ -1724,6 +1724,7 @@ def __init__(self):

def test_is_from_pixels_dict_env_no_pixels(self):
"""Test that _is_from_pixels correctly identifies Dict environments without pixels."""

# Test with a Dict environment that doesn't have pixels
class DictNoPixelEnv:
def __init__(self):
Expand Down Expand Up @@ -1869,6 +1870,7 @@ def mock_isinstance(obj, cls):
def test_gymnasium_num_envs(self, num_envs, request):
if not _has_gymnasium:
pytest.skip("gymnasium not found")
import gymnasium

gym_version = version.parse(gymnasium.__version__)
if version.parse("1.0.0") <= gym_version < version.parse("1.1.0"):
Expand Down Expand Up @@ -1904,6 +1906,8 @@ class TestMiniGrid:
],
)
def test_minigrid(self, id):
import gymnasium

env_base = gymnasium.make(id)
env = GymWrapper(env_base)
check_env_specs(env)
Expand Down
13 changes: 3 additions & 10 deletions test/modules/_modules_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ def _has_triton_backend() -> bool:
_has_triton = _has_triton_backend()
_triton_skip_reason = "requires triton (>= 2.2) and CUDA"

_has_functorch = False
try:
try:
from torch import vmap as vmap # noqa: F401
except ImportError:
from functorch import vmap as vmap # noqa: F401

_has_functorch = True
except ImportError:
pass
_has_functorch = (
hasattr(torch, "vmap") or importlib.util.find_spec("functorch") is not None
)
23 changes: 16 additions & 7 deletions test/modules/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,19 @@
)

_has_hoptorch = importlib.util.find_spec("hoptorch") is not None
_vmap = None

if _has_functorch:
try:
from torch import vmap
except ImportError:
from functorch import vmap

def _get_vmap():
global _vmap
if _vmap is None:
if hasattr(torch, "vmap"):
_vmap = torch.vmap
else:
from functorch import vmap

_vmap = vmap
return _vmap


@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -552,7 +559,6 @@ def test_lstm_parallel_env(
def _test_lstm_parallel_env(
self, python_based, parallel, heterogeneous, within, maybe_fork_ParallelEnv
):

torch.manual_seed(0)
num_envs = 3
device = "cuda" if torch.cuda.device_count() else "cpu"
Expand Down Expand Up @@ -654,6 +660,8 @@ def test_lstm_parallel_within(
not _has_functorch, reason="vmap can only be used with functorch"
)
def test_lstm_vmap_complex_model(self):
vmap = _get_vmap()

# Tests that all ops in GRU are compatible with VMAP (when build using
# the PT backend).
# This used to fail when splitting the input based on the is_init mask.
Expand Down Expand Up @@ -2512,6 +2520,8 @@ def test_gru_parallel_within(
not _has_functorch, reason="vmap can only be used with functorch"
)
def test_gru_vmap_complex_model(self):
vmap = _get_vmap()

# Tests that all ops in GRU are compatible with VMAP (when build using
# the PT backend).
# This used to fail when splitting the input based on the is_init mask.
Expand Down Expand Up @@ -3286,7 +3296,6 @@ def test_gru_module_scan_non_canonical_hidden_strides(self):


def test_get_primers_from_module():

# No primers in the model
module = MLP(in_features=10, out_features=10, num_cells=[])
transform = get_primers_from_module(module)
Expand Down
14 changes: 4 additions & 10 deletions test/objectives/_objectives_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,10 @@
from torchrl.data import Composite, Unbounded
from torchrl.envs import EnvBase

_has_functorch = True
try:
import functorch as ft # noqa

make_functional_with_buffers = ft.make_functional_with_buffers
FUNCTORCH_ERR = ""
except ImportError as err:
_has_functorch = False
FUNCTORCH_ERR = str(err)
make_functional_with_buffers = None
_has_functorch = (
hasattr(torch, "vmap") or importlib.util.find_spec("functorch") is not None
)
FUNCTORCH_ERR = ""

_has_transformers = bool(importlib.util.find_spec("transformers"))
_has_botorch = bool(importlib.util.find_spec("botorch"))
Expand Down
4 changes: 2 additions & 2 deletions test/objectives/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
_has_transformers,
FUNCTORCH_ERR,
LossModuleTestBase,
make_functional_with_buffers,
MARLEnv,
)

Expand Down Expand Up @@ -1175,7 +1174,6 @@ def test_ppo_notensordict(
loss_val = loss(**kwargs)
torch.manual_seed(self.seed)
if beta is not None:

loss.beta = beta.clone()
loss_val_td = loss(td)

Expand Down Expand Up @@ -2163,6 +2161,8 @@ def test_a2c_diff(self, device, gradient_mode, advantage, composite_action_dist)

loss_fn = A2CLoss(actor, value, loss_critic_type="l2")

from functorch import make_functional_with_buffers

floss_fn, params, buffers = make_functional_with_buffers(loss_fn)

if advantage is not None:
Expand Down
Loading
Loading