41
41
unravel_key ,
42
42
)
43
43
from tensordict .base import NO_DEFAULT
44
- from tensordict .utils import _getitem_batch_size , NestedKey
44
+ from tensordict .utils import _getitem_batch_size , is_non_tensor , NestedKey
45
45
from torchrl ._utils import _make_ordinal_device , get_binary_env_var , implement_for
46
46
47
47
DEVICE_TYPING = Union [torch .device , str , int ]
@@ -582,6 +582,16 @@ def clear_device_(self) -> T:
582
582
"""
583
583
return self
584
584
585
    @abc.abstractmethod
    def cardinality(self) -> int:
        """The cardinality of the spec.

        This refers to the number of possible outcomes in a spec. It is assumed that the cardinality of a composite
        spec is the cartesian product of all possible outcomes.

        Concrete subclasses in this file return an ``int`` for discrete specs,
        ``float("inf")`` for continuous/unbounded specs, and raise
        ``RuntimeError`` for non-tensor specs.
        """
        ...
585
595
def encode (
586
596
self ,
587
597
val : np .ndarray | torch .Tensor | TensorDictBase ,
@@ -1515,6 +1525,9 @@ def __init__(
1515
1525
def n (self ):
1516
1526
return self .space .n
1517
1527
1528
+ def cardinality (self ) -> int :
1529
+ return self .n
1530
+
1518
1531
def update_mask (self , mask ):
1519
1532
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
1520
1533
@@ -2107,6 +2120,9 @@ def enumerate(self) -> Any:
2107
2120
f"enumerate is not implemented for spec of class { type (self ).__name__ } ."
2108
2121
)
2109
2122
2123
    def cardinality(self) -> float:
        """Return the cardinality of this spec.

        Continuous/unbounded specs have infinitely many outcomes, so this
        returns ``float("inf")``.
        """
        # NOTE(review): annotated ``float`` (not ``int`` like the abstract
        # method) because the returned value is the float infinity.
        return float("inf")
2125
+
2110
2126
def __eq__ (self , other ):
2111
2127
return (
2112
2128
type (other ) == type (self )
@@ -2426,8 +2442,11 @@ def __init__(
2426
2442
shape = shape , space = None , device = device , dtype = dtype , domain = domain , ** kwargs
2427
2443
)
2428
2444
2445
+ def cardinality (self ) -> Any :
2446
+ raise RuntimeError ("Cannot enumerate a NonTensorSpec." )
2447
+
2429
2448
def enumerate (self ) -> Any :
2430
- raise NotImplementedError ("Cannot enumerate a NonTensorSpec." )
2449
+ raise RuntimeError ("Cannot enumerate a NonTensorSpec." )
2431
2450
2432
2451
def to (self , dest : Union [torch .dtype , DEVICE_TYPING ]) -> NonTensor :
2433
2452
if isinstance (dest , torch .dtype ):
@@ -2466,10 +2485,10 @@ def one(self, shape=None):
2466
2485
data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2467
2486
)
2468
2487
2469
- def is_in (self , val : torch . Tensor ) -> bool :
2488
+ def is_in (self , val : Any ) -> bool :
2470
2489
shape = torch .broadcast_shapes (self ._safe_shape , val .shape )
2471
2490
return (
2472
- isinstance (val , NonTensorData )
2491
+ is_non_tensor (val )
2473
2492
and val .shape == shape
2474
2493
# We relax constrains on device as they're hard to enforce for non-tensor
2475
2494
# tensordicts and pointless
@@ -2832,6 +2851,9 @@ def __init__(
2832
2851
)
2833
2852
self .update_mask (mask )
2834
2853
2854
+ def cardinality (self ) -> int :
2855
+ return torch .as_tensor (self .nvec ).prod ()
2856
+
2835
2857
def enumerate (self ) -> torch .Tensor :
2836
2858
nvec = self .nvec
2837
2859
enum_disc = self .to_categorical_spec ().enumerate ()
@@ -3220,13 +3242,20 @@ class Categorical(TensorSpec):
3220
3242
The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is
3221
3243
desired for the training dimension, one should specify it explicitly.
3222
3244
3245
+ Attributes:
3246
+ n (int): The number of possible outcomes.
3247
+ shape (torch.Size): The shape of the variable.
3248
+ device (torch.device): The device of the tensors.
3249
+ dtype (torch.dtype): The dtype of the tensors.
3250
+
3223
3251
Args:
3224
- n (int): number of possible outcomes.
3252
+ n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined,
3253
+ and `set_provisional_n` must be called before sampling from this spec.
3225
3254
shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])".
3226
- device (str, int or torch.device, optional): device of the tensors.
3227
- dtype (str or torch.dtype, optional): dtype of the tensors.
3228
- mask (torch.Tensor or None): mask some of the possible outcomes when a
3229
- sample is taken. See :meth:`~.update_mask` for more information.
3255
+ device (str, int or torch.device, optional): the device of the tensors.
3256
+ dtype (str or torch.dtype, optional): the dtype of the tensors.
3257
+ mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken.
3258
+ See :meth:`~.update_mask` for more information.
3230
3259
3231
3260
Examples:
3232
3261
>>> categ = Categorical(3)
@@ -3249,6 +3278,13 @@ class Categorical(TensorSpec):
3249
3278
domain=discrete)
3250
3279
>>> categ.rand()
3251
3280
tensor([1])
3281
+ >>> categ = Categorical(-1)
3282
+ >>> categ.set_provisional_n(5)
3283
+ >>> categ.rand()
3284
+ tensor(3)
3285
+
3286
+ .. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n`
3287
+ will raise a ``RuntimeError``.
3252
3288
3253
3289
"""
3254
3290
@@ -3276,16 +3312,31 @@ def __init__(
3276
3312
shape = shape , space = space , device = device , dtype = dtype , domain = "discrete"
3277
3313
)
3278
3314
self .update_mask (mask )
3315
+ self ._provisional_n = None
3279
3316
3280
3317
def enumerate (self ) -> torch .Tensor :
3281
- arange = torch .arange (self .n , dtype = self .dtype , device = self .device )
3318
+ dtype = self .dtype
3319
+ if dtype is torch .bool :
3320
+ dtype = torch .uint8
3321
+ arange = torch .arange (self .n , dtype = dtype , device = self .device )
3282
3322
if self .ndim :
3283
3323
arange = arange .view (- 1 , * (1 ,) * self .ndim )
3284
3324
return arange .expand (self .n , * self .shape )
3285
3325
3286
3326
@property
3287
3327
def n (self ):
3288
- return self .space .n
3328
+ n = self .space .n
3329
+ if n == - 1 :
3330
+ n = self ._provisional_n
3331
+ if n is None :
3332
+ raise RuntimeError (
3333
+ f"Undefined cardinality for { type (self )} . Please call "
3334
+ f"spec.set_provisional_n(int)."
3335
+ )
3336
+ return n
3337
+
3338
    def cardinality(self) -> int:
        """Return the number of possible outcomes of this categorical spec.

        Delegates to the ``n`` property, which raises a ``RuntimeError`` when
        the cardinality is undefined (``n=-1``) and no provisional n was set.
        """
        return self.n
3289
3340
3290
3341
def update_mask (self , mask ):
3291
3342
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
@@ -3316,13 +3367,33 @@ def update_mask(self, mask):
3316
3367
raise ValueError ("Only boolean masks are accepted." )
3317
3368
self .mask = mask
3318
3369
3370
+ def set_provisional_n (self , n : int ):
3371
+ """Set the cardinality of the Categorical spec temporarily.
3372
+
3373
+ This method is required to be called before sampling from the spec when n is -1.
3374
+
3375
+ Args:
3376
+ n (int): The cardinality of the Categorical spec.
3377
+
3378
+ """
3379
+ self ._provisional_n = n
3380
+
3319
3381
def rand (self , shape : torch .Size = None ) -> torch .Tensor :
3382
+ if self .space .n < 0 :
3383
+ if self ._provisional_n is None :
3384
+ raise RuntimeError (
3385
+ "Cannot generate random categorical samples for undefined cardinality (n=-1). "
3386
+ "To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()."
3387
+ )
3388
+ n = self ._provisional_n
3389
+ else :
3390
+ n = self .space .n
3320
3391
if shape is None :
3321
3392
shape = _size ([])
3322
3393
if self .mask is None :
3323
3394
return torch .randint (
3324
3395
0 ,
3325
- self . space . n ,
3396
+ n ,
3326
3397
_size ([* shape , * self .shape ]),
3327
3398
device = self .device ,
3328
3399
dtype = self .dtype ,
@@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor:
3334
3405
else :
3335
3406
mask_flat = mask
3336
3407
shape_out = mask .shape [:- 1 ]
3408
+ # Check that the mask has the right size
3409
+ if mask_flat .shape [- 1 ] != n :
3410
+ raise ValueError (
3411
+ "The last dimension of the mask must match the number of action allowed by the "
3412
+ f"Categorical spec. Got mask.shape={ self .mask .shape } and n={ n } ."
3413
+ )
3337
3414
out = torch .multinomial (mask_flat .float (), 1 ).reshape (shape_out )
3338
3415
return out
3339
3416
@@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool:
3360
3437
dtype_match = val .dtype == self .dtype
3361
3438
if not dtype_match :
3362
3439
return False
3440
+ if self .space .n == - 1 :
3441
+ return True
3363
3442
return (0 <= val ).all () and (val < self .space .n ).all ()
3364
3443
shape = self .mask .shape
3365
3444
shape = _size ([* torch .broadcast_shapes (shape [:- 1 ], val .shape ), shape [- 1 ]])
@@ -3607,7 +3686,7 @@ def __init__(
3607
3686
device : Optional [DEVICE_TYPING ] = None ,
3608
3687
dtype : Union [str , torch .dtype ] = torch .int8 ,
3609
3688
):
3610
- if n is None and not shape :
3689
+ if n is None and shape is None :
3611
3690
raise TypeError ("Must provide either n or shape." )
3612
3691
if n is None :
3613
3692
n = shape [- 1 ]
@@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor:
3813
3892
arange = arange .expand (arange .shape [0 ], * self .shape )
3814
3893
return arange
3815
3894
3895
    def cardinality(self) -> int:
        # Product of the per-dim category counts of ``nvec``.
        # NOTE(review): ``_base`` presumably exposes the underlying flat tensor
        # of a nested ``nvec`` -- TODO confirm; also note ``.prod()`` returns a
        # 0-dim tensor rather than a Python int despite the annotation.
        return self.nvec._base.prod()
3897
+
3816
3898
def update_mask (self , mask ):
3817
3899
"""Sets a mask to prevent some of the possible outcomes when a sample is taken.
3818
3900
@@ -4373,7 +4455,7 @@ def set(self, name, spec):
4373
4455
shape = spec .shape
4374
4456
if shape [: self .ndim ] != self .shape :
4375
4457
if (
4376
- isinstance (spec , Composite )
4458
+ isinstance (spec , ( Composite , NonTensor ) )
4377
4459
and spec .ndim < self .ndim
4378
4460
and self .shape [: spec .ndim ] == spec .shape
4379
4461
):
@@ -4382,7 +4464,7 @@ def set(self, name, spec):
4382
4464
spec .shape = self .shape
4383
4465
else :
4384
4466
raise ValueError (
4385
- "The shape of the spec and the Composite mismatch: the first "
4467
+ f "The shape of the spec { type ( spec ). __name__ } and the Composite { type ( self ). __name__ } mismatch: the first "
4386
4468
f"{ self .ndim } dimensions should match but got spec.shape={ spec .shape } and "
4387
4469
f"Composite.shape={ self .shape } ."
4388
4470
)
@@ -4798,6 +4880,18 @@ def clone(self) -> Composite:
4798
4880
shape = self .shape ,
4799
4881
)
4800
4882
4883
+ def cardinality (self ) -> int :
4884
+ n = None
4885
+ for spec in self .values ():
4886
+ if spec is None :
4887
+ continue
4888
+ if n is None :
4889
+ n = 1
4890
+ n = n * spec .cardinality ()
4891
+ if n is None :
4892
+ n = 0
4893
+ return n
4894
+
4801
4895
def enumerate (self ) -> TensorDictBase :
4802
4896
# We are going to use meshgrid to create samples of all the subspecs in here
4803
4897
# but first let's get rid of the batch size, we'll put it back later
0 commit comments