Skip to content

Commit 336dc98

Browse files
authored
[Refactor] Box device (#881)
1 parent c5493ec commit 336dc98

File tree

5 files changed

+55
-10
lines changed

5 files changed

+55
-10
lines changed

test/_utils_internal.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import contextlib
67
import os
78
import time
89
from functools import wraps
@@ -104,3 +105,13 @@ def dtype_fixture():
104105
torch.set_default_dtype(torch.double)
105106
yield dtype
106107
torch.set_default_dtype(dtype)
108+
109+
110+
@contextlib.contextmanager
111+
def set_global_var(module, var_name, value):
112+
old_value = getattr(module, var_name)
113+
setattr(module, var_name, value)
114+
try:
115+
yield
116+
finally:
117+
setattr(module, var_name, old_value)

test/test_specs.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
import pytest
99
import torch
10-
from _utils_internal import get_available_devices
10+
import torchrl.data.tensor_specs
11+
from _utils_internal import get_available_devices, set_global_var
1112
from scipy.stats import chisquare
1213
from tensordict.tensordict import TensorDict, TensorDictBase
1314
from torchrl.data.tensor_specs import (
@@ -57,8 +58,11 @@ def test_discrete(cls):
5758
ts.encode(torch.tensor([5]))
5859
ts.encode(torch.tensor(5).numpy())
5960
ts.encode(9)
60-
with pytest.raises(AssertionError):
61+
with pytest.raises(AssertionError), set_global_var(
62+
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
63+
):
6164
ts.encode(torch.tensor([11])) # out of bounds
65+
assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE
6266
assert ts.is_in(r)
6367
assert (ts.encode(ts.to_numpy(r)) == r).all()
6468

@@ -114,10 +118,15 @@ def test_ndbounded(dtype, shape):
114118
ts.encode(lb + torch.rand(10) * (ub - lb))
115119
ts.encode((lb + torch.rand(10) * (ub - lb)).numpy())
116120
assert (ts.encode(ts.to_numpy(r)) == r).all()
117-
with pytest.raises(AssertionError):
121+
with pytest.raises(AssertionError), set_global_var(
122+
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
123+
):
118124
ts.encode(torch.rand(10) + 3) # out of bounds
119-
with pytest.raises(AssertionError):
125+
with pytest.raises(AssertionError), set_global_var(
126+
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
127+
):
120128
ts.to_numpy(torch.rand(10) + 3) # out of bounds
129+
assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE
121130

122131

123132
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None])

torchrl/data/tensor_specs.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535

3636
INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]
3737

38-
_NO_CHECK_SPEC_ENCODE = get_binary_env_var("NO_CHECK_SPEC_ENCODE")
38+
# By default, we do not check that an obs is in the domain. THis should be done when validating the env beforehand
39+
_CHECK_SPEC_ENCODE = get_binary_env_var("CHECK_SPEC_ENCODE")
40+
3941

4042
_DEFAULT_SHAPE = torch.Size((1,))
4143

@@ -108,8 +110,28 @@ def clone(self) -> DiscreteBox:
108110
class ContinuousBox(Box):
109111
"""A continuous box of values, in between a minimum and a maximum."""
110112

111-
minimum: torch.Tensor
112-
maximum: torch.Tensor
113+
_minimum: torch.Tensor
114+
_maximum: torch.Tensor
115+
device: torch.device = None
116+
117+
# We store the tensors on CPU to avoid overloading CUDA with tensors that are rarely used.
118+
@property
119+
def minimum(self):
120+
return self._minimum.to(self.device)
121+
122+
@property
123+
def maximum(self):
124+
return self._maximum.to(self.device)
125+
126+
@minimum.setter
127+
def minimum(self, value):
128+
self.device = value.device
129+
self._minimum = value.cpu()
130+
131+
@maximum.setter
132+
def maximum(self, value):
133+
self.device = value.device
134+
self._maximum = value.cpu()
113135

114136
def __post_init__(self):
115137
self.minimum = self.minimum.clone()
@@ -257,7 +279,7 @@ def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
257279
f"Shape mismatch: the value has shape {val.shape} which "
258280
f"is incompatible with the spec shape {self.shape}."
259281
)
260-
if not _NO_CHECK_SPEC_ENCODE:
282+
if _CHECK_SPEC_ENCODE:
261283
self.assert_is_in(val)
262284
return val
263285

torchrl/envs/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ class EnvBase(nn.Module, metaclass=abc.ABCMeta):
210210
- run_type_checks (bool): if True, the observation and reward dtypes
211211
will be compared against their respective spec and an exception
212212
will be raised if they don't match.
213+
Defaults to False.
213214
214215
Methods:
215216
step (TensorDictBase -> TensorDictBase): step in the environment
@@ -226,7 +227,7 @@ def __init__(
226227
device: DEVICE_TYPING = "cpu",
227228
dtype: Optional[Union[torch.dtype, np.dtype]] = None,
228229
batch_size: Optional[torch.Size] = None,
229-
run_type_checks: bool = True,
230+
run_type_checks: bool = False,
230231
):
231232
super().__init__()
232233
if device is not None:

torchrl/envs/transforms/transforms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2738,7 +2738,9 @@ def transform_observation_spec(
27382738
dtype=torch.int64,
27392739
device=observation_spec.device,
27402740
)
2741-
observation_spec["step_count"].space.minimum = 0
2741+
observation_spec["step_count"].space.minimum = (
2742+
observation_spec["step_count"].space.minimum * 0
2743+
)
27422744
return observation_spec
27432745

27442746

0 commit comments

Comments
 (0)