9
9
import warnings
10
10
from typing import Dict , List , Tuple , Union
11
11
12
+ import numpy as np
12
13
import packaging
13
14
import torch
14
15
from tensordict import TensorDictBase
@@ -72,6 +73,19 @@ def _load_available_envs() -> Dict:
72
73
return all_environments
73
74
74
75
76
+ def _extract_nested_with_index (
77
+ data : Union [np .ndarray , Dict [str , np .ndarray ]], index : int
78
+ ):
79
+ if isinstance (data , np .ndarray ):
80
+ return data [index ]
81
+ elif isinstance (data , dict ):
82
+ return {
83
+ key : _extract_nested_with_index (value , index ) for key , value in data .items ()
84
+ }
85
+ else :
86
+ raise NotImplementedError (f"Invalid type of data { data } " )
87
+
88
+
75
89
class PettingZooWrapper (_EnvWrapper ):
76
90
"""PettingZoo environment wrapper.
77
91
@@ -735,7 +749,9 @@ def _step_parallel(
735
749
"full_action_spec" , group , "action"
736
750
].to_numpy (group_action )
737
751
for index , agent in enumerate (agents ):
738
- action_dict [agent ] = group_action_np [index ]
752
+ # group_action_np can be a dict or an array. We need to recursively index it
753
+ action = _extract_nested_with_index (group_action_np , index )
754
+ action_dict [agent ] = action
739
755
740
756
return self ._env .step (action_dict )
741
757
@@ -750,7 +766,8 @@ def _step_aec(
750
766
group_action_np = self .input_spec [
751
767
"full_action_spec" , group , "action"
752
768
].to_numpy (group_action )
753
- action = group_action_np [agent_index ]
769
+ # group_action_np can be a dict or an array. We need to recursively index it
770
+ action = _extract_nested_with_index (group_action_np , agent_index )
754
771
break
755
772
756
773
self ._env .step (action )
0 commit comments