5
5
from __future__ import annotations
6
6
7
7
import importlib .util
8
+ import warnings
8
9
9
10
from typing import Dict , List , Optional , Union
10
11
@@ -328,9 +329,9 @@ def _make_specs(
328
329
self .group_map = self .group_map .get_group_map (self .agent_names )
329
330
check_marl_grouping (self .group_map , self .agent_names )
330
331
331
- self . unbatched_action_spec = Composite (device = self .device )
332
- self . unbatched_observation_spec = Composite (device = self .device )
333
- self . unbatched_reward_spec = Composite (device = self .device )
332
+ full_action_spec_unbatched = Composite (device = self .device )
333
+ full_observation_spec_unbatched = Composite (device = self .device )
334
+ full_reward_spec_unbatched = Composite (device = self .device )
334
335
335
336
self .het_specs = False
336
337
self .het_specs_map = {}
@@ -341,18 +342,18 @@ def _make_specs(
341
342
group_reward_spec ,
342
343
group_info_spec ,
343
344
) = self ._make_unbatched_group_specs (group )
344
- self . unbatched_action_spec [group ] = group_action_spec
345
- self . unbatched_observation_spec [group ] = group_observation_spec
346
- self . unbatched_reward_spec [group ] = group_reward_spec
345
+ full_action_spec_unbatched [group ] = group_action_spec
346
+ full_observation_spec_unbatched [group ] = group_observation_spec
347
+ full_reward_spec_unbatched [group ] = group_reward_spec
347
348
if group_info_spec is not None :
348
- self . unbatched_observation_spec [(group , "info" )] = group_info_spec
349
+ full_observation_spec_unbatched [(group , "info" )] = group_info_spec
349
350
group_het_specs = isinstance (
350
351
group_observation_spec , StackedComposite
351
352
) or isinstance (group_action_spec , StackedComposite )
352
353
self .het_specs_map [group ] = group_het_specs
353
354
self .het_specs = self .het_specs or group_het_specs
354
355
355
- self . unbatched_done_spec = Composite (
356
+ full_done_spec_unbatched = Composite (
356
357
{
357
358
"done" : Categorical (
358
359
n = 2 ,
@@ -363,18 +364,42 @@ def _make_specs(
363
364
},
364
365
)
365
366
366
- self .action_spec = self .unbatched_action_spec .expand (
367
- * self .batch_size , * self .unbatched_action_spec .shape
367
+ self .full_action_spec_unbatched = full_action_spec_unbatched
368
+ self .full_observation_spec_unbatched = full_observation_spec_unbatched
369
+ self .full_reward_spec_unbatched = full_reward_spec_unbatched
370
+ self .full_done_spec_unbatched = full_done_spec_unbatched
371
+
372
@property
def unbatched_action_spec(self):
    """Deprecated alias of ``full_action_spec_unbatched``.

    Emits a ``DeprecationWarning`` and returns
    ``self.full_action_spec_unbatched`` unchanged.
    """
    warnings.warn(
        "unbatched_action_spec is deprecated and will be removed in v0.9. "
        "Please use full_action_spec_unbatched instead.",
        category=DeprecationWarning,
        # Attribute the warning to the code that accessed the property,
        # not to this getter.
        stacklevel=2,
    )
    return self.full_action_spec_unbatched
379
+
380
@property
def unbatched_observation_spec(self):
    """Deprecated alias of ``full_observation_spec_unbatched``.

    Emits a ``DeprecationWarning`` and returns
    ``self.full_observation_spec_unbatched`` unchanged.
    """
    warnings.warn(
        "unbatched_observation_spec is deprecated and will be removed in v0.9. "
        "Please use full_observation_spec_unbatched instead.",
        category=DeprecationWarning,
        # Attribute the warning to the code that accessed the property,
        # not to this getter.
        stacklevel=2,
    )
    return self.full_observation_spec_unbatched
387
+
388
@property
def unbatched_reward_spec(self):
    """Deprecated alias of ``full_reward_spec_unbatched``.

    Emits a ``DeprecationWarning`` and returns
    ``self.full_reward_spec_unbatched`` unchanged.
    """
    warnings.warn(
        "unbatched_reward_spec is deprecated and will be removed in v0.9. "
        "Please use full_reward_spec_unbatched instead.",
        category=DeprecationWarning,
        # Attribute the warning to the code that accessed the property,
        # not to this getter.
        stacklevel=2,
    )
    return self.full_reward_spec_unbatched
395
+
396
@property
def unbatched_done_spec(self):
    """Deprecated alias of ``full_done_spec_unbatched``.

    Emits a ``DeprecationWarning`` and returns
    ``self.full_done_spec_unbatched`` unchanged.
    """
    warnings.warn(
        "unbatched_done_spec is deprecated and will be removed in v0.9. "
        "Please use full_done_spec_unbatched instead.",
        category=DeprecationWarning,
        # Attribute the warning to the code that accessed the property,
        # not to this getter.
        stacklevel=2,
    )
    return self.full_done_spec_unbatched
378
403
379
404
def _make_unbatched_group_specs (self , group : str ):
380
405
# Agent specs
@@ -618,7 +643,9 @@ def read_reward(self, rewards):
618
643
619
644
def read_action(self, action, group: str = "agents"):
    """Split a group's action into per-agent pieces.

    Args:
        action: the action object for ``group``; it must expose
            ``unbind(dim=...)``.
        group: name of the agent group whose action spec is consulted.
            Defaults to ``"agents"``.

    Returns:
        The result of ``action.unbind(dim=1)``.
        # NOTE(review): presumably one entry per agent, with dim 1 being the
        # agent dimension -- confirm against the caller.
    """
    # When the env is configured with neither continuous nor categorical
    # actions, the action is first passed through the group's action spec
    # ``to_categorical``.
    needs_categorical = not (self.continuous_actions or self.categorical_actions)
    if needs_categorical:
        group_action_spec = self.full_action_spec_unbatched[group, "action"]
        action = group_action_spec.to_categorical(action)
    return action.unbind(dim=1)
624
651
0 commit comments