
Commit f3275da

Author: Vincent Moens

[CI, BugFix] Py3.8 for old deps

ghstack-source-id: 13c7923
Pull Request resolved: #2568

1 parent a4c1ee3 commit f3275da
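
In short: the "old deps" CI job is pinned back to Python 3.8, and the code it exercises is made compatible with that interpreter. PEP 585 built-in generics (`tuple[int, ...]`), which require Python 3.9, are replaced with `typing.Tuple`; the conda numpy pin is relaxed from `numpy==1.26` to `"numpy<2.0"`; a `TrajCounter` test is rewritten to compare distinct trajectory ids and to shut its collector down; and two torch-version-dependent code paths (`rand` in tensor_specs.py, `__getstate__` in transforms.py) are gated with `implement_for`.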

File tree

6 files changed: +58 -15 lines changed


.github/unittest/linux_olddeps/scripts_gym_0_13/install.sh

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [ "${CU_VERSION:-}" == cpu ] ; then
     conda install pytorch==2.0 torchvision==0.15 cpuonly -c pytorch -y
 else
-    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 numpy==1.26 -c pytorch -c nvidia -y
+    conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 "numpy<2.0" -c pytorch -c nvidia -y
 fi

 # Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
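
Two details in the new pin are worth noting: `numpy==1.26` cannot be solved in a Python 3.8 environment (numpy 1.25 and later require Python 3.9), while `"numpy<2.0"` lets conda pick a 3.8-compatible 1.x release and still excludes numpy 2.x, whose ABI change breaks extensions built against numpy 1.x, such as the torch 2.0.x binaries installed here. The quotes keep the shell from parsing `<` as a redirection. A hypothetical runtime check of the resulting environment (not part of the commit):

```python
# Illustrative sanity check, not part of the commit: the old-deps env
# must resolve to a NumPy 1.x release, because the pinned torch 2.0.x
# binaries were built against the NumPy 1.x ABI.
import numpy as np
from packaging.version import Version  # assumes `packaging` is installed

assert Version(np.__version__) < Version("2.0"), (
    f"expected numpy < 2.0 in the old-deps environment, got {np.__version__}"
)
```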

.github/workflows/test-linux.yml

Lines changed: 1 addition & 3 deletions

@@ -128,13 +128,11 @@ jobs:
     with:
       repository: pytorch/rl
       runner: "linux.g5.4xlarge.nvidia.gpu"
-      # gpu-arch-type: cuda
-      # gpu-arch-version: "11.7"
       docker-image: "nvidia/cudagl:11.4.0-base"
       timeout: 120
       script: |
         set -euo pipefail
-        export PYTHON_VERSION="3.9"
+        export PYTHON_VERSION="3.8"
         export CU_VERSION="cu116"
         export TAR_OPTIONS="--no-same-owner"
         if [[ "${{ github.ref }}" =~ release/* ]]; then

test/test_distributions.py

Lines changed: 3 additions & 2 deletions

@@ -6,6 +6,7 @@
 import argparse
 import importlib.util
 import os
+from typing import Tuple

 import pytest
 import torch

@@ -685,7 +686,7 @@ class TestOrdinal:
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("logit_shape", [(10,), (1, 1), (10, 10), (5, 10, 20)])
     def test_correct_sampling_shape(
-        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+        self, logit_shape: Tuple[int, ...], dtype: torch.dtype, device: str
     ) -> None:
         logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)

@@ -753,7 +754,7 @@ class TestOneHotOrdinal:
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
     def test_correct_sampling_shape(
-        self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
+        self, logit_shape: Tuple[int, ...], dtype: torch.dtype, device: str
     ) -> None:
         logits = torch.testing.make_tensor(logit_shape, dtype=dtype, device=device)
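
These annotation changes are the heart of the Python 3.8 fix: PEP 585 built-in generics such as `tuple[int, ...]` only exist from Python 3.9, so the old annotations raise at import time on 3.8, whereas `typing.Tuple[int, ...]` is accepted on every supported version. A minimal reproduction of the failure mode:

```python
# Why the annotations were downgraded: on Python 3.8 the built-in
# `tuple` is not subscriptable, so PEP 585 annotations fail at function
# definition time.
from typing import Tuple

def ok(shape: Tuple[int, ...]) -> None:  # accepted on 3.8 and later
    pass

# Uncommenting the following on Python 3.8 raises
#   TypeError: 'type' object is not subscriptable
# def broken(shape: tuple[int, ...]) -> None:
#     pass
```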

test/test_transforms.py

Lines changed: 11 additions & 4 deletions

@@ -2089,10 +2089,17 @@ def make_env(max_steps=4):
             total_frames=99,
             frames_per_batch=8,
         )
-        for d in collector:
-            # The env has one more traj because the collector calls reset during init
-            assert d["collector", "traj_ids"].max() == d["next", "traj_count"].max() - 1
-            assert d["traj_count"].max() > 0
+
+        try:
+            traj_ids_collector = []
+            traj_ids_env = []
+            for d in collector:
+                traj_ids_collector.extend(d["collector", "traj_ids"].view(-1).tolist())
+                traj_ids_env.extend(d["next", "traj_count"].view(-1).tolist())
+            assert len(set(traj_ids_env)) == len(set(traj_ids_collector))
+        finally:
+            collector.shutdown()
+            del collector

     def test_transform_compose(self):
         t = TrajCounter()
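
The rewritten test drops the exact off-by-one comparison between the collector's `traj_ids` and the env's `traj_count` (which is sensitive to how many resets happen during collector init) in favor of comparing the number of distinct trajectory ids seen on each side, and the `finally` block guarantees the collector's workers are released even when an assertion fails. The cleanup pattern in isolation, as a sketch (the collector construction is assumed):

```python
# Sketch of the cleanup pattern used above; `collector` stands in for an
# already-built torchrl data collector.
def drain_and_shutdown(collector):
    """Iterate a collector and always release its worker processes."""
    try:
        for batch in collector:
            ...  # assertions on the collected batch go here
    finally:
        collector.shutdown()  # stop background envs / worker processes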

torchrl/data/tensor_specs.py

Lines changed: 28 additions & 4 deletions

@@ -42,7 +42,7 @@
 )
 from tensordict.base import NO_DEFAULT
 from tensordict.utils import _getitem_batch_size, NestedKey
-from torchrl._utils import _make_ordinal_device, get_binary_env_var
+from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for

 DEVICE_TYPING = Union[torch.device, str, int]

@@ -193,14 +193,14 @@ def _slice_indexing(shape: list[int], idx: slice) -> List[int]:


 def _shape_indexing(
-    shape: Union[list[int], torch.Size, tuple[int]], idx: SHAPE_INDEX_TYPING
+    shape: Union[list[int], torch.Size, Tuple[int]], idx: SHAPE_INDEX_TYPING
 ) -> List[int]:
     """Given an input shape and an index, returns the size of the resulting indexed spec.

     This function includes indexing checks and may raise IndexErrors.

     Args:
-        shape (list[int], torch.Size, tuple[int): Input shape
+        shape (list[int], torch.Size, Tuple[int]): Input shape
         idx (SHAPE_INDEX_TYPING): Index
     Returns:
         Shape of the resulting spec

@@ -1020,7 +1020,7 @@ def unbind(self, dim: int = 0):


 class _LazyStackedMixin(Generic[T]):
-    def __init__(self, *specs: tuple[T, ...], dim: int) -> None:
+    def __init__(self, *specs: Tuple[T, ...], dim: int) -> None:
         self._specs = list(specs)
         self.dim = dim
         if self.dim < 0:

@@ -1682,7 +1682,31 @@ def unbind(self, dim: int = 0):
             for i in range(self.shape[dim])
         )

+    @implement_for("torch", None, "2.1")
     def rand(self, shape: torch.Size = None) -> torch.Tensor:
+        if shape is None:
+            shape = self.shape[:-1]
+        else:
+            shape = _size([*shape, *self.shape[:-1]])
+        mask = self.mask
+        n = int(self.space.n)
+        if mask is None:
+            m = torch.randint(n, shape, device=self.device)
+        else:
+            mask = mask.expand(_remove_neg_shapes(*shape, mask.shape[-1]))
+            if mask.ndim > 2:
+                mask_flat = torch.flatten(mask, 0, -2)
+            else:
+                mask_flat = mask
+            shape_out = mask.shape[:-1]
+            m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
+        out = torch.nn.functional.one_hot(m, n).to(self.dtype)
+        # torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype)
+        # out.scatter_(-1, m, 1)
+        return out
+
+    @implement_for("torch", "2.1")
+    def rand(self, shape: torch.Size = None) -> torch.Tensor:  # noqa: F811
         if shape is None:
             shape = self.shape[:-1]
         else:
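
`implement_for` is torchrl's version-dispatch decorator: both `rand` definitions are registered under the same name, and the variant whose declared torch range (`[None, 2.1)` vs `[2.1, None)`) matches the installed torch is the one that ends up bound; the `# noqa: F811` silences flake8's redefinition warning. The pre-2.1 branch added here samples masked one-hot values through `torch.multinomial`, presumably to sidestep behavior that differs in older torch releases. A minimal sketch of the dispatch idea, with illustrative names rather than torchrl's real implementation:

```python
# Minimal sketch of version-gated dispatch (illustrative; torchrl's
# implement_for is more general). The decorator registers each variant
# and always returns the one matching the installed torch version.
import torch
from packaging.version import Version  # assumes `packaging` is installed

_TORCH = Version(torch.__version__.split("+")[0])
_chosen = {}  # qualname -> implementation selected so far

def implement_for(module, from_version=None, to_version=None):
    def decorator(fn):
        in_range = (
            from_version is None or _TORCH >= Version(from_version)
        ) and (to_version is None or _TORCH < Version(to_version))
        if in_range:
            _chosen[fn.__qualname__] = fn
        # Return the selected implementation, so a later redefinition of
        # the same name never shadows the matching variant.
        return _chosen.get(fn.__qualname__, fn)
    return decorator

@implement_for("torch", None, "2.1")
def rand():
    return "pre-2.1 fallback"

@implement_for("torch", "2.1")
def rand():  # noqa: F811
    return "2.1+ implementation"
```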

torchrl/envs/transforms/transforms.py

Lines changed: 14 additions & 1 deletion

@@ -52,7 +52,13 @@
 from torch import nn, Tensor
 from torch.utils._pytree import tree_map

-from torchrl._utils import _append_last, _ends_with, _make_ordinal_device, _replace_last
+from torchrl._utils import (
+    _append_last,
+    _ends_with,
+    _make_ordinal_device,
+    _replace_last,
+    implement_for,
+)

 from torchrl.data.tensor_specs import (
     Binary,

@@ -8772,7 +8778,14 @@ def __init__(self, out_key: NestedKey = "traj_count"):
     def _make_shared_value(self):
         self._traj_count = mp.Value("i", 0)

+    @implement_for("torch", None, "2.1")
     def __getstate__(self):
+        state = self.__dict__.copy()
+        state["_traj_count"] = None
+        return state
+
+    @implement_for("torch", "2.1")
+    def __getstate__(self):  # noqa: F811
         state = super().__getstate__()
         state["_traj_count"] = None
         return state
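
The same gating appears here for pickling: on the old stack, `super().__getstate__()` has nothing to resolve to (`nn.Module` only gained a `__getstate__` of its own around torch 2.1, and a default `object.__getstate__` exists only from Python 3.11), so the pre-2.1 branch copies `self.__dict__` by hand. Either way, the shared `mp.Value` counter is nulled out, because synchronized multiprocessing objects cannot be pickled. A standalone sketch of the pattern, with an illustrative class rather than the real `TrajCounter`:

```python
# Illustrative sketch, not the real TrajCounter: shared multiprocessing
# state is dropped before pickling and recreated after unpickling.
import multiprocessing as mp
import pickle

class SharedCounter:
    def __init__(self):
        self._traj_count = mp.Value("i", 0)  # process-shared integer

    def __getstate__(self):
        state = self.__dict__.copy()  # portable across Python/torch versions
        state["_traj_count"] = None   # mp.Value handles cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._traj_count = mp.Value("i", 0)  # fresh handle after unpickling

restored = pickle.loads(pickle.dumps(SharedCounter()))
assert restored._traj_count.value == 0
```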
