Skip to content

Commit ba0faef

Browse files
louisfauryLouis Faury
andauthored
[BugFix] Fixes the Categorical is_in with non-long integer (#2981)
Co-authored-by: Louis Faury <louis.faury@helsing.ai>
1 parent 00657f0 commit ba0faef

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchrl/data/tensor_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3966,7 +3966,7 @@ def is_in(self, val: torch.Tensor) -> bool:
39663966
shape = self.mask.shape
39673967
shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
39683968
mask_expand = self.mask.expand(shape)
3969-
gathered = mask_expand.gather(-1, val.unsqueeze(-1))
3969+
gathered = mask_expand.gather(-1, val.unsqueeze(-1).to(torch.long))
39703970
return gathered.all()
39713971

39723972
def __getitem__(self, idx: SHAPE_INDEX_TYPING):

0 commit comments

Comments
 (0)