@@ -1402,7 +1402,6 @@ def type_check(self, value: torch.Tensor, key: NestedKey | None = None) -> None:
1402
1402
spec .type_check (val )
1403
1403
1404
1404
def is_in (self , value ) -> bool :
1405
- raise RuntimeError
1406
1405
if self .dim == 0 and not hasattr (value , "unbind" ):
1407
1406
# We don't use unbind because value could be a tuple or a nested tensor
1408
1407
return all (
@@ -1834,7 +1833,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
1834
1833
return val
1835
1834
1836
1835
def is_in (self , val : torch .Tensor ) -> bool :
1837
- raise RuntimeError
1838
1836
if self .mask is None :
1839
1837
shape = torch .broadcast_shapes (self ._safe_shape , val .shape )
1840
1838
shape_match = val .shape == shape
@@ -2288,7 +2286,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
2288
2286
return val
2289
2287
2290
2288
def is_in (self , val : torch .Tensor ) -> bool :
2291
- raise RuntimeError
2292
2289
val_shape = _remove_neg_shapes (tensordict .utils ._shape (val ))
2293
2290
shape = torch .broadcast_shapes (self ._safe_shape , val_shape )
2294
2291
shape = list (shape )
@@ -2489,7 +2486,6 @@ def one(self, shape=None):
2489
2486
)
2490
2487
2491
2488
def is_in (self , val : Any ) -> bool :
2492
- raise RuntimeError
2493
2489
shape = torch .broadcast_shapes (self ._safe_shape , val .shape )
2494
2490
return (
2495
2491
is_non_tensor (val )
@@ -2682,7 +2678,6 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
2682
2678
return torch .empty (shape , device = self .device , dtype = self .dtype ).random_ ()
2683
2679
2684
2680
def is_in (self , val : torch .Tensor ) -> bool :
2685
- raise RuntimeError
2686
2681
shape = torch .broadcast_shapes (self ._safe_shape , val .shape )
2687
2682
return val .shape == shape and val .dtype == self .dtype
2688
2683
@@ -3034,7 +3029,6 @@ def index(self, index: INDEX_TYPING, tensor_to_index: torch.Tensor) -> torch.Ten
3034
3029
return torch .cat (out , - 1 )
3035
3030
3036
3031
def is_in (self , val : torch .Tensor ) -> bool :
3037
- raise RuntimeError
3038
3032
vals = self ._split (val )
3039
3033
if vals is None :
3040
3034
return False
@@ -3435,7 +3429,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
3435
3429
return val
3436
3430
3437
3431
def is_in (self , val : torch .Tensor ) -> bool :
3438
- raise RuntimeError
3439
3432
if self .mask is None :
3440
3433
shape = torch .broadcast_shapes (self ._safe_shape , val .shape )
3441
3434
shape_match = val .shape == shape
@@ -4066,7 +4059,6 @@ def _project(self, val: torch.Tensor) -> torch.Tensor:
4066
4059
return val .squeeze (0 ) if val_is_scalar else val
4067
4060
4068
4061
def is_in (self , val : torch .Tensor ) -> bool :
4069
- raise RuntimeError
4070
4062
if self .mask is not None :
4071
4063
vals = val .unbind (- 1 )
4072
4064
splits = self ._split_self ()
0 commit comments