Commit bb6f87a

Author: Vincent Moens

[BugFix] Remove raisers in specs

ghstack-source-id: a005a62
Pull Request resolved: #2651

1 parent: 9e2d214

1 file changed

torchrl/data/tensor_specs.py

Lines changed: 0 additions & 8 deletions
@@ -1402,7 +1402,6 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
             spec.type_check(val)

     def is_in(self, value) -> bool:
-        raise RuntimeError
         if self.dim == 0 and not hasattr(value, "unbind"):
             # We don't use unbind because value could be a tuple or a nested tensor
             return all(
@@ -1834,7 +1833,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         return val

     def is_in(self, val: torch.Tensor) -> bool:
-        raise RuntimeError
         if self.mask is None:
             shape = torch.broadcast_shapes(self._safe_shape, val.shape)
             shape_match = val.shape == shape
@@ -2288,7 +2286,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         return val

     def is_in(self, val: torch.Tensor) -> bool:
-        raise RuntimeError
         val_shape = _remove_neg_shapes(tensordict.utils._shape(val))
         shape = torch.broadcast_shapes(self._safe_shape, val_shape)
         shape = list(shape)
@@ -2489,7 +2486,6 @@ def one(self, shape=None):
         )

     def is_in(self, val: Any) -> bool:
-        raise RuntimeError
         shape = torch.broadcast_shapes(self._safe_shape, val.shape)
         return (
             is_non_tensor(val)
@@ -2682,7 +2678,6 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
         return torch.empty(shape, device=self.device, dtype=self.dtype).random_()

     def is_in(self, val: torch.Tensor) -> bool:
-        raise RuntimeError
         shape = torch.broadcast_shapes(self._safe_shape, val.shape)
         return val.shape == shape and val.dtype == self.dtype

@@ -3034,7 +3029,6 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Tensor:
         return torch.cat(out, -1)

     def is_in(self, val: torch.Tensor) -> bool:
-        raise RuntimeError
         vals = self._split(val)
         if vals is None:
             return False
@@ -3435,7 +3429,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         return val

     def is_in(self, val: torch.Tensor) -> bool:
-        raise RuntimeError
         if self.mask is None:
             shape = torch.broadcast_shapes(self._safe_shape, val.shape)
             shape_match = val.shape == shape
@@ -4066,7 +4059,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
         return val.squeeze(0) if val_is_scalar else val

     def is_in(self, val: torch.Tensor) -> bool:
-        raise RuntimeError
         if self.mask is not None:
             vals = val.unbind(-1)
             splits = self._split_self()
