@@ -560,6 +560,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
560
560
output_spec ["full_done_spec" ] = self .transform_done_spec (
561
561
output_spec ["full_done_spec" ]
562
562
)
563
+ output_spec_keys = [
564
+ unravel_key (k [1 :]) for k in output_spec .keys (True ) if isinstance (k , tuple )
565
+ ]
566
+ out_keys = {unravel_key (k ) for k in self .out_keys }
567
+ in_keys = {unravel_key (k ) for k in self .in_keys }
568
+ for key in out_keys - in_keys :
569
+ if unravel_key (key ) not in output_spec_keys :
570
+ warnings .warn (
571
+ f"The key '{ key } ' is unaccounted for by the transform (expected keys { output_spec_keys } ). "
572
+ f"Every new entry in the tensordict resulting from a call to a transform must be "
573
+ f"registered in the specs for torchrl rollouts to be consistently built. "
574
+ f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
575
+ "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly." ,
576
+ category = FutureWarning ,
577
+ )
563
578
return output_spec
564
579
565
580
def transform_input_spec (self , input_spec : TensorSpec ) -> TensorSpec :
@@ -1468,33 +1483,57 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
1468
1483
# the action spec from the env, map it using t0 then t1 (going from in to out).
1469
1484
for t in self .transforms :
1470
1485
input_spec = t .transform_input_spec (input_spec )
1486
+ if not isinstance (input_spec , Composite ):
1487
+ raise TypeError (
1488
+ f"Expected Compose but got { type (input_spec )} with transform { t } "
1489
+ )
1471
1490
return input_spec
1472
1491
1473
1492
def transform_action_spec(self, action_spec: TensorSpec) -> TensorSpec:
    """Chains the action-spec transform of every sub-transform, in order.

    Args:
        action_spec (TensorSpec): the action spec to be transformed.

    Returns:
        TensorSpec: the spec produced by the last sub-transform.

    Raises:
        TypeError: if any sub-transform returns something that is not a
            ``TensorSpec``.
    """
    # To understand why we don't invert, look up at transform_input_spec
    for t in self.transforms:
        action_spec = t.transform_action_spec(action_spec)
        # Guard clause: a well-behaved transform must hand back a spec.
        if isinstance(action_spec, TensorSpec):
            continue
        raise TypeError(
            f"Expected TensorSpec but got {type(action_spec)} with transform {t}"
        )
    return action_spec
def transform_state_spec(self, state_spec: TensorSpec) -> TensorSpec:
    """Chains the state-spec transform of every sub-transform, in order.

    Args:
        state_spec (TensorSpec): the state spec to be transformed.

    Returns:
        TensorSpec: the spec produced by the last sub-transform.

    Raises:
        TypeError: if any sub-transform returns something that is not a
            ``Composite``.
    """
    # To understand why we don't invert, look up at transform_input_spec
    for t in self.transforms:
        state_spec = t.transform_state_spec(state_spec)
        if not isinstance(state_spec, Composite):
            # BUG FIX: the message previously said "Expected Compose", but the
            # check is against the spec class ``Composite`` (``Compose`` is the
            # transform container) — the old text pointed users at the wrong type.
            raise TypeError(
                f"Expected Composite but got {type(state_spec)} with transform {t}"
            )
    return state_spec
def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
    """Chains the observation-spec transform of every sub-transform, in order.

    Args:
        observation_spec (TensorSpec): the observation spec to be transformed.

    Returns:
        TensorSpec: the spec produced by the last sub-transform.

    Raises:
        TypeError: if any sub-transform returns something that is not a
            ``TensorSpec``.
    """
    for t in self.transforms:
        # Each transform sees the spec as shaped by its predecessors.
        transformed = t.transform_observation_spec(observation_spec)
        if not isinstance(transformed, TensorSpec):
            raise TypeError(
                f"Expected TensorSpec but got {type(transformed)} with transform {t}"
            )
        observation_spec = transformed
    return observation_spec
def transform_output_spec(self, output_spec: TensorSpec) -> TensorSpec:
    """Chains the output-spec transform of every sub-transform, in order.

    Args:
        output_spec (TensorSpec): the output spec to be transformed.

    Returns:
        TensorSpec: the spec produced by the last sub-transform.

    Raises:
        TypeError: if any sub-transform returns something that is not a
            ``Composite``.
    """
    for t in self.transforms:
        output_spec = t.transform_output_spec(output_spec)
        if not isinstance(output_spec, Composite):
            # BUG FIX: the message previously said "Expected Compose", but the
            # check is against the spec class ``Composite`` (``Compose`` is the
            # transform container) — the old text pointed users at the wrong type.
            raise TypeError(
                f"Expected Composite but got {type(output_spec)} with transform {t}"
            )
    return output_spec
def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
    """Chains the reward-spec transform of every sub-transform, in order.

    Args:
        reward_spec (TensorSpec): the reward spec to be transformed.

    Returns:
        TensorSpec: the spec produced by the last sub-transform.

    Raises:
        TypeError: if any sub-transform returns something that is not a
            ``TensorSpec``.
    """
    for t in self.transforms:
        reward_spec = t.transform_reward_spec(reward_spec)
        # Fail loudly on a transform that forgets to return a spec.
        if not isinstance(reward_spec, TensorSpec):
            raise TypeError(
                f"Expected TensorSpec but got {type(reward_spec)} with transform {t}"
            )
    return reward_spec
def __getitem__ (self , item : Union [int , slice , List ]) -> Union :
0 commit comments