Skip to content

Commit dd26ae7

Browse files
author
Vincent Moens
committed
[Feature] spec.cardinality
ghstack-source-id: 1160900 Pull Request resolved: #2638
1 parent 4bc40a8 commit dd26ae7

File tree

3 files changed

+207
-15
lines changed

3 files changed

+207
-15
lines changed

test/test_specs.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,85 @@ def test_unboundeddiscrete(
16891689
assert spec is not spec.clone()
16901690

16911691

1692+
class TestCardinality:
1693+
@pytest.mark.parametrize("shape1", [(5, 4)])
1694+
def test_binary(self, shape1):
1695+
spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool)
1696+
assert spec.cardinality() == len(list(spec.enumerate()))
1697+
1698+
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
1699+
def test_discrete(
1700+
self,
1701+
shape1,
1702+
):
1703+
spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long)
1704+
assert spec.cardinality() == len(list(spec.enumerate()))
1705+
1706+
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
1707+
def test_multidiscrete(
1708+
self,
1709+
shape1,
1710+
):
1711+
if shape1 is None:
1712+
shape1 = (3,)
1713+
else:
1714+
shape1 = (*shape1, 3)
1715+
spec = MultiCategorical(
1716+
nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long
1717+
)
1718+
assert spec.cardinality() == len(spec.enumerate())
1719+
1720+
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
1721+
def test_multionehot(
1722+
self,
1723+
shape1,
1724+
):
1725+
if shape1 is None:
1726+
shape1 = (15,)
1727+
else:
1728+
shape1 = (*shape1, 15)
1729+
spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long)
1730+
assert spec.cardinality() == len(list(spec.enumerate()))
1731+
1732+
def test_non_tensor(self):
1733+
spec = NonTensor(shape=(3, 4), device="cpu")
1734+
with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."):
1735+
spec.cardinality()
1736+
1737+
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
1738+
def test_onehot(
1739+
self,
1740+
shape1,
1741+
):
1742+
if shape1 is None:
1743+
shape1 = (15,)
1744+
else:
1745+
shape1 = (*shape1, 15)
1746+
spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long)
1747+
assert spec.cardinality() == len(list(spec.enumerate()))
1748+
1749+
def test_composite(self):
1750+
batch_size = (5,)
1751+
spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool)
1752+
spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long)
1753+
spec4 = MultiCategorical(
1754+
nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long
1755+
)
1756+
spec5 = MultiOneHot(
1757+
nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long
1758+
)
1759+
spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long)
1760+
spec = Composite(
1761+
spec2=spec2,
1762+
spec3=spec3,
1763+
spec4=spec4,
1764+
spec5=spec5,
1765+
spec6=spec6,
1766+
shape=batch_size,
1767+
)
1768+
assert spec.cardinality() == len(spec.enumerate())
1769+
1770+
16921771
class TestUnbind:
16931772
@pytest.mark.parametrize("shape1", [(5, 4)])
16941773
def test_binary(self, shape1):

torchrl/data/tensor_specs.py

Lines changed: 109 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
unravel_key,
4242
)
4343
from tensordict.base import NO_DEFAULT
44-
from tensordict.utils import _getitem_batch_size, NestedKey
44+
from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey
4545
from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for
4646

4747
DEVICE_TYPING = Union[torch.device, str, int]
@@ -582,6 +582,16 @@ def clear_device_(self) -> T:
582582
"""
583583
return self
584584

585+
@abc.abstractmethod
586+
def cardinality(self) -> int:
587+
"""The cardinality of the spec.
588+
589+
This refers to the number of possible outcomes in a spec. It is assumed that the cardinality of a composite
590+
spec is the cartesian product of all possible outcomes.
591+
592+
"""
593+
...
594+
585595
def encode(
586596
self,
587597
val: np.ndarray | torch.Tensor | TensorDictBase,
@@ -1515,6 +1525,9 @@ def __init__(
15151525
def n(self):
15161526
return self.space.n
15171527

1528+
def cardinality(self) -> int:
1529+
return self.n
1530+
15181531
def update_mask(self, mask):
15191532
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
15201533
@@ -2107,6 +2120,9 @@ def enumerate(self) -> Any:
21072120
f"enumerate is not implemented for spec of class {type(self).__name__}."
21082121
)
21092122

2123+
def cardinality(self) -> int:
2124+
return float("inf")
2125+
21102126
def __eq__(self, other):
21112127
return (
21122128
type(other) == type(self)
@@ -2426,8 +2442,11 @@ def __init__(
24262442
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
24272443
)
24282444

2445+
def cardinality(self) -> Any:
2446+
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
2447+
24292448
def enumerate(self) -> Any:
2430-
raise NotImplementedError("Cannot enumerate a NonTensorSpec.")
2449+
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
24312450

24322451
def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
24332452
if isinstance(dest, torch.dtype):
@@ -2466,10 +2485,10 @@ def one(self, shape=None):
24662485
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
24672486
)
24682487

2469-
def is_in(self, val: torch.Tensor) -> bool:
2488+
def is_in(self, val: Any) -> bool:
24702489
shape = torch.broadcast_shapes(self._safe_shape, val.shape)
24712490
return (
2472-
isinstance(val, NonTensorData)
2491+
is_non_tensor(val)
24732492
and val.shape == shape
24742493
# We relax constrains on device as they're hard to enforce for non-tensor
24752494
# tensordicts and pointless
@@ -2832,6 +2851,9 @@ def __init__(
28322851
)
28332852
self.update_mask(mask)
28342853

2854+
def cardinality(self) -> int:
2855+
return torch.as_tensor(self.nvec).prod()
2856+
28352857
def enumerate(self) -> torch.Tensor:
28362858
nvec = self.nvec
28372859
enum_disc = self.to_categorical_spec().enumerate()
@@ -3220,13 +3242,20 @@ class Categorical(TensorSpec):
32203242
The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is
32213243
desired for the training dimension, one should specify it explicitly.
32223244
3245+
Attributes:
3246+
n (int): The number of possible outcomes.
3247+
shape (torch.Size): The shape of the variable.
3248+
device (torch.device): The device of the tensors.
3249+
dtype (torch.dtype): The dtype of the tensors.
3250+
32233251
Args:
3224-
n (int): number of possible outcomes.
3252+
n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined,
3253+
and `set_provisional_n` must be called before sampling from this spec.
32253254
shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])".
3226-
device (str, int or torch.device, optional): device of the tensors.
3227-
dtype (str or torch.dtype, optional): dtype of the tensors.
3228-
mask (torch.Tensor or None): mask some of the possible outcomes when a
3229-
sample is taken. See :meth:`~.update_mask` for more information.
3255+
device (str, int or torch.device, optional): the device of the tensors.
3256+
dtype (str or torch.dtype, optional): the dtype of the tensors.
3257+
mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken.
3258+
See :meth:`~.update_mask` for more information.
32303259
32313260
Examples:
32323261
>>> categ = Categorical(3)
@@ -3249,6 +3278,13 @@ class Categorical(TensorSpec):
32493278
domain=discrete)
32503279
>>> categ.rand()
32513280
tensor([1])
3281+
>>> categ = Categorical(-1)
3282+
>>> categ.set_provisional_n(5)
3283+
>>> categ.rand()
3284+
tensor(3)
3285+
3286+
.. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n`
3287+
will raise a ``RuntimeError``.
32523288
32533289
"""
32543290

@@ -3276,16 +3312,31 @@ def __init__(
32763312
shape=shape, space=space, device=device, dtype=dtype, domain="discrete"
32773313
)
32783314
self.update_mask(mask)
3315+
self._provisional_n = None
32793316

32803317
def enumerate(self) -> torch.Tensor:
3281-
arange = torch.arange(self.n, dtype=self.dtype, device=self.device)
3318+
dtype = self.dtype
3319+
if dtype is torch.bool:
3320+
dtype = torch.uint8
3321+
arange = torch.arange(self.n, dtype=dtype, device=self.device)
32823322
if self.ndim:
32833323
arange = arange.view(-1, *(1,) * self.ndim)
32843324
return arange.expand(self.n, *self.shape)
32853325

32863326
@property
32873327
def n(self):
3288-
return self.space.n
3328+
n = self.space.n
3329+
if n == -1:
3330+
n = self._provisional_n
3331+
if n is None:
3332+
raise RuntimeError(
3333+
f"Undefined cardinality for {type(self)}. Please call "
3334+
f"spec.set_provisional_n(int)."
3335+
)
3336+
return n
3337+
3338+
def cardinality(self) -> int:
3339+
return self.n
32893340

32903341
def update_mask(self, mask):
32913342
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
@@ -3316,13 +3367,33 @@ def update_mask(self, mask):
33163367
raise ValueError("Only boolean masks are accepted.")
33173368
self.mask = mask
33183369

3370+
def set_provisional_n(self, n: int):
3371+
"""Set the cardinality of the Categorical spec temporarily.
3372+
3373+
This method is required to be called before sampling from the spec when n is -1.
3374+
3375+
Args:
3376+
n (int): The cardinality of the Categorical spec.
3377+
3378+
"""
3379+
self._provisional_n = n
3380+
33193381
def rand(self, shape: torch.Size = None) -> torch.Tensor:
3382+
if self.space.n < 0:
3383+
if self._provisional_n is None:
3384+
raise RuntimeError(
3385+
"Cannot generate random categorical samples for undefined cardinality (n=-1). "
3386+
"To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()."
3387+
)
3388+
n = self._provisional_n
3389+
else:
3390+
n = self.space.n
33203391
if shape is None:
33213392
shape = _size([])
33223393
if self.mask is None:
33233394
return torch.randint(
33243395
0,
3325-
self.space.n,
3396+
n,
33263397
_size([*shape, *self.shape]),
33273398
device=self.device,
33283399
dtype=self.dtype,
@@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
33343405
else:
33353406
mask_flat = mask
33363407
shape_out = mask.shape[:-1]
3408+
# Check that the mask has the right size
3409+
if mask_flat.shape[-1] != n:
3410+
raise ValueError(
3411+
"The last dimension of the mask must match the number of action allowed by the "
3412+
f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}."
3413+
)
33373414
out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out)
33383415
return out
33393416

@@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool:
33603437
dtype_match = val.dtype == self.dtype
33613438
if not dtype_match:
33623439
return False
3440+
if self.space.n == -1:
3441+
return True
33633442
return (0 <= val).all() and (val < self.space.n).all()
33643443
shape = self.mask.shape
33653444
shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]])
@@ -3607,7 +3686,7 @@ def __init__(
36073686
device: Optional[DEVICE_TYPING] = None,
36083687
dtype: Union[str, torch.dtype] = torch.int8,
36093688
):
3610-
if n is None and not shape:
3689+
if n is None and shape is None:
36113690
raise TypeError("Must provide either n or shape.")
36123691
if n is None:
36133692
n = shape[-1]
@@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor:
38133892
arange = arange.expand(arange.shape[0], *self.shape)
38143893
return arange
38153894

3895+
def cardinality(self) -> int:
3896+
return self.nvec._base.prod()
3897+
38163898
def update_mask(self, mask):
38173899
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
38183900
@@ -4373,7 +4455,7 @@ def set(self, name, spec):
43734455
shape = spec.shape
43744456
if shape[: self.ndim] != self.shape:
43754457
if (
4376-
isinstance(spec, Composite)
4458+
isinstance(spec, (Composite, NonTensor))
43774459
and spec.ndim < self.ndim
43784460
and self.shape[: spec.ndim] == spec.shape
43794461
):
@@ -4382,7 +4464,7 @@ def set(self, name, spec):
43824464
spec.shape = self.shape
43834465
else:
43844466
raise ValueError(
4385-
"The shape of the spec and the Composite mismatch: the first "
4467+
f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
43864468
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
43874469
f"Composite.shape={self.shape}."
43884470
)
@@ -4798,6 +4880,18 @@ def clone(self) -> Composite:
47984880
shape=self.shape,
47994881
)
48004882

4883+
def cardinality(self) -> int:
4884+
n = None
4885+
for spec in self.values():
4886+
if spec is None:
4887+
continue
4888+
if n is None:
4889+
n = 1
4890+
n = n * spec.cardinality()
4891+
if n is None:
4892+
n = 0
4893+
return n
4894+
48014895
def enumerate(self) -> TensorDictBase:
48024896
# We are going to use meshgrid to create samples of all the subspecs in here
48034897
# but first let's get rid of the batch size, we'll put it back later

torchrl/envs/common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,25 @@ def check_env_specs(self, *args, **kwargs):
561561

562562
check_env_specs.__doc__ = check_env_specs_func.__doc__
563563

564+
def cardinality(self, tensordict: TensorDictBase | None = None) -> int:
565+
"""The cardinality of the action space.
566+
567+
By default, this is just a wrapper around :meth:`env.action_space.cardinality <~torchrl.data.TensorSpec.cardinality>`.
568+
569+
This class is useful when the action spec is variable:
570+
571+
- The number of actions can be undefined, e.g., ``Categorical(n=-1)``;
572+
- The action cardinality may depend on the action mask;
573+
- The shape can be dynamic, as in ``Unbound(shape=(-1))``.
574+
575+
In these cases, the :meth:`~.cardinality` should be overwritten,
576+
577+
Args:
578+
tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality.
579+
580+
"""
581+
return self.full_action_spec.cardinality()
582+
564583
@classmethod
565584
def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
566585
# inplace update will write tensors in-place on the provided tensordict.

0 commit comments

Comments
 (0)