Skip to content

Commit a126a6f

Browse files
author
Vincent Moens
committed
[Refactor] Use <spec>_unbatched in VMAS
ghstack-source-id: 2190278 Pull Request resolved: #2593
1 parent d30599e commit a126a6f

File tree

2 files changed

+45
-18
lines changed

2 files changed

+45
-18
lines changed

sota-implementations/multiagent/qmix_vdn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def train(cfg: "DictConfig"): # noqa: F821
110110
if cfg.loss.mixer_type == "qmix":
111111
mixer = TensorDictModule(
112112
module=QMixer(
113-
state_shape=env.unbatched_observation_spec[
113+
state_shape=env.observation_spec_unbatched[
114114
"agents", "observation"
115115
].shape,
116116
mixing_embed_dim=32,

torchrl/envs/libs/vmas.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import importlib.util
8+
import warnings
89

910
from typing import Dict, List, Optional, Union
1011

@@ -328,9 +329,9 @@ def _make_specs(
328329
self.group_map = self.group_map.get_group_map(self.agent_names)
329330
check_marl_grouping(self.group_map, self.agent_names)
330331

331-
self.unbatched_action_spec = Composite(device=self.device)
332-
self.unbatched_observation_spec = Composite(device=self.device)
333-
self.unbatched_reward_spec = Composite(device=self.device)
332+
full_action_spec_unbatched = Composite(device=self.device)
333+
full_observation_spec_unbatched = Composite(device=self.device)
334+
full_reward_spec_unbatched = Composite(device=self.device)
334335

335336
self.het_specs = False
336337
self.het_specs_map = {}
@@ -341,18 +342,18 @@ def _make_specs(
341342
group_reward_spec,
342343
group_info_spec,
343344
) = self._make_unbatched_group_specs(group)
344-
self.unbatched_action_spec[group] = group_action_spec
345-
self.unbatched_observation_spec[group] = group_observation_spec
346-
self.unbatched_reward_spec[group] = group_reward_spec
345+
full_action_spec_unbatched[group] = group_action_spec
346+
full_observation_spec_unbatched[group] = group_observation_spec
347+
full_reward_spec_unbatched[group] = group_reward_spec
347348
if group_info_spec is not None:
348-
self.unbatched_observation_spec[(group, "info")] = group_info_spec
349+
full_observation_spec_unbatched[(group, "info")] = group_info_spec
349350
group_het_specs = isinstance(
350351
group_observation_spec, StackedComposite
351352
) or isinstance(group_action_spec, StackedComposite)
352353
self.het_specs_map[group] = group_het_specs
353354
self.het_specs = self.het_specs or group_het_specs
354355

355-
self.unbatched_done_spec = Composite(
356+
full_done_spec_unbatched = Composite(
356357
{
357358
"done": Categorical(
358359
n=2,
@@ -363,18 +364,42 @@ def _make_specs(
363364
},
364365
)
365366

366-
self.action_spec = self.unbatched_action_spec.expand(
367-
*self.batch_size, *self.unbatched_action_spec.shape
367+
self.full_action_spec_unbatched = full_action_spec_unbatched
368+
self.full_observation_spec_unbatched = full_observation_spec_unbatched
369+
self.full_reward_spec_unbatched = full_reward_spec_unbatched
370+
self.full_done_spec_unbatched = full_done_spec_unbatched
371+
372+
@property
373+
def unbatched_action_spec(self):
374+
warnings.warn(
375+
"unbatched_action_spec is deprecated and will be removed in v0.9. "
376+
"Please use full_action_spec_unbatched instead."
368377
)
369-
self.observation_spec = self.unbatched_observation_spec.expand(
370-
*self.batch_size, *self.unbatched_observation_spec.shape
378+
return self.full_action_spec_unbatched
379+
380+
@property
381+
def unbatched_observation_spec(self):
382+
warnings.warn(
383+
"unbatched_observation_spec is deprecated and will be removed in v0.9. "
384+
"Please use full_observation_spec_unbatched instead."
371385
)
372-
self.reward_spec = self.unbatched_reward_spec.expand(
373-
*self.batch_size, *self.unbatched_reward_spec.shape
386+
return self.full_observation_spec_unbatched
387+
388+
@property
389+
def unbatched_reward_spec(self):
390+
warnings.warn(
391+
"unbatched_reward_spec is deprecated and will be removed in v0.9. "
392+
"Please use full_reward_spec_unbatched instead."
374393
)
375-
self.done_spec = self.unbatched_done_spec.expand(
376-
*self.batch_size, *self.unbatched_done_spec.shape
394+
return self.full_reward_spec_unbatched
395+
396+
@property
397+
def unbatched_done_spec(self):
398+
warnings.warn(
399+
"unbatched_done_spec is deprecated and will be removed in v0.9. "
400+
"Please use full_done_spec_unbatched instead."
377401
)
402+
return self.full_done_spec_unbatched
378403

379404
def _make_unbatched_group_specs(self, group: str):
380405
# Agent specs
@@ -618,7 +643,9 @@ def read_reward(self, rewards):
618643

619644
def read_action(self, action, group: str = "agents"):
620645
if not self.continuous_actions and not self.categorical_actions:
621-
action = self.unbatched_action_spec[group, "action"].to_categorical(action)
646+
action = self.full_action_spec_unbatched[group, "action"].to_categorical(
647+
action
648+
)
622649
agent_actions = action.unbind(dim=1)
623650
return agent_actions
624651

0 commit comments

Comments
 (0)