Skip to content

Commit 1a6c9e2

Browse files
[BugFix] PettingZoo dict action spaces (#2692)
1 parent 61e05b3 commit 1a6c9e2

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

torchrl/envs/libs/pettingzoo.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
from typing import Dict, List, Tuple, Union
1111

12+
import numpy as np
1213
import packaging
1314
import torch
1415
from tensordict import TensorDictBase
@@ -72,6 +73,19 @@ def _load_available_envs() -> Dict:
7273
return all_environments
7374

7475

76+
def _extract_nested_with_index(
77+
data: Union[np.ndarray, Dict[str, np.ndarray]], index: int
78+
):
79+
if isinstance(data, np.ndarray):
80+
return data[index]
81+
elif isinstance(data, dict):
82+
return {
83+
key: _extract_nested_with_index(value, index) for key, value in data.items()
84+
}
85+
else:
86+
raise NotImplementedError(f"Invalid type of data {data}")
87+
88+
7589
class PettingZooWrapper(_EnvWrapper):
7690
"""PettingZoo environment wrapper.
7791
@@ -735,7 +749,9 @@ def _step_parallel(
735749
"full_action_spec", group, "action"
736750
].to_numpy(group_action)
737751
for index, agent in enumerate(agents):
738-
action_dict[agent] = group_action_np[index]
752+
# group_action_np can be a dict or an array. We need to recursively index it
753+
action = _extract_nested_with_index(group_action_np, index)
754+
action_dict[agent] = action
739755

740756
return self._env.step(action_dict)
741757

@@ -750,7 +766,8 @@ def _step_aec(
750766
group_action_np = self.input_spec[
751767
"full_action_spec", group, "action"
752768
].to_numpy(group_action)
753-
action = group_action_np[agent_index]
769+
# group_action_np can be a dict or an array. We need to recursively index it
770+
action = _extract_nested_with_index(group_action_np, agent_index)
754771
break
755772

756773
self._env.step(action)

0 commit comments

Comments
 (0)