Skip to content

Commit 72c60fb

Browse files
louisfauryLouis Faury
andauthored
[BugFix] Binary can have empty shape (#2979)
Co-authored-by: Louis Faury <louis.faury@helsing.ai>
1 parent 524a3f5 commit 72c60fb

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

test/test_specs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,9 @@ def test_ndbounded_shape(self):
13871387
assert (-3 <= sample).all() and (3 >= sample).all()
13881388
assert sample.shape == torch.Size([100, 10, 5])
13891389

1390+
def test_binary_empty_shape_construct(self):
1391+
assert len(Binary(shape=()).shape) == 0
1392+
13901393

13911394
class TestExpand:
13921395
@pytest.mark.parametrize("shape1", [None, (4,), (5, 4)])

torchrl/data/tensor_specs.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4305,17 +4305,23 @@ def __init__(
43054305
):
43064306
if n is None and shape is None:
43074307
raise TypeError("Must provide either n or shape.")
4308-
if n is None:
4309-
n = shape[-1]
4310-
if shape is None or not len(shape):
4311-
shape = _size((n,))
4308+
4309+
# Either `shape` or `n` is not `None`.
4310+
shape = _size((n,)) if shape is None else _size(shape)
4311+
4312+
# Consistency checks between `shape` and `n`.
4313+
if len(shape) == 0:
4314+
if n is not None and n != 0:
4315+
raise ValueError(
4316+
f"'n' must be zero for spec {self.__class__} when using an empty shape"
4317+
)
43124318
else:
4313-
shape = _size(shape)
43144319
if shape[-1] != n:
43154320
raise ValueError(
4316-
f"The last value of the shape must match n for spec {self.__class__}. "
4321+
f"The last value of the shape must match 'n' for spec {self.__class__}. "
43174322
f"Got n={n} and shape={shape}."
43184323
)
4324+
43194325
super().__init__(n=2, shape=shape, device=device, dtype=dtype)
43204326
self.encode = self._encode_eager
43214327

0 commit comments

Comments
 (0)