@@ -131,9 +131,9 @@ def make_policy(env):
131
131
132
132
133
133
def _is_consistent_device_type (
134
- device_type , policy_device_type , passing_device_type , tensordict_device_type
134
+ device_type , policy_device_type , storing_device_type , tensordict_device_type
135
135
):
136
- if passing_device_type is None :
136
+ if storing_device_type is None :
137
137
if device_type is None :
138
138
if policy_device_type is None :
139
139
return tensordict_device_type == "cpu"
@@ -142,7 +142,7 @@ def _is_consistent_device_type(
142
142
143
143
return tensordict_device_type == device_type
144
144
145
- return tensordict_device_type == passing_device_type
145
+ return tensordict_device_type == storing_device_type
146
146
147
147
148
148
@pytest .mark .skipif (
@@ -152,12 +152,12 @@ def _is_consistent_device_type(
152
152
@pytest .mark .parametrize ("num_env" , [1 , 2 ])
153
153
@pytest .mark .parametrize ("device" , ["cuda" , "cpu" , None ])
154
154
@pytest .mark .parametrize ("policy_device" , ["cuda" , "cpu" , None ])
155
- @pytest .mark .parametrize ("passing_device " , ["cuda" , "cpu" , None ])
155
+ @pytest .mark .parametrize ("storing_device " , ["cuda" , "cpu" , None ])
156
156
def test_output_device_consistency (
157
- num_env , device , policy_device , passing_device , seed = 40
157
+ num_env , device , policy_device , storing_device , seed = 40
158
158
):
159
159
if (
160
- device == "cuda" or policy_device == "cuda" or passing_device == "cuda"
160
+ device == "cuda" or policy_device == "cuda" or storing_device == "cuda"
161
161
) and not torch .cuda .is_available ():
162
162
pytest .skip ("cuda is not available" )
163
163
@@ -169,7 +169,7 @@ def test_output_device_consistency(
169
169
170
170
_device = "cuda:0" if device == "cuda" else device
171
171
_policy_device = "cuda:0" if policy_device == "cuda" else policy_device
172
- _passing_device = "cuda:0" if passing_device == "cuda" else passing_device
172
+ _storing_device = "cuda:0" if storing_device == "cuda" else storing_device
173
173
174
174
if num_env == 1 :
175
175
@@ -201,12 +201,12 @@ def env_fn(seed):
201
201
max_frames_per_traj = 2000 ,
202
202
total_frames = 20000 ,
203
203
device = _device ,
204
- passing_device = _passing_device ,
204
+ storing_device = _storing_device ,
205
205
pin_memory = False ,
206
206
)
207
207
for _ , d in enumerate (collector ):
208
208
assert _is_consistent_device_type (
209
- device , policy_device , passing_device , d .device .type
209
+ device , policy_device , storing_device , d .device .type
210
210
)
211
211
break
212
212
@@ -220,13 +220,13 @@ def env_fn(seed):
220
220
max_frames_per_traj = 2000 ,
221
221
total_frames = 20000 ,
222
222
device = _device ,
223
- passing_device = _passing_device ,
223
+ storing_device = _storing_device ,
224
224
pin_memory = False ,
225
225
)
226
226
227
227
for _ , d in enumerate (ccollector ):
228
228
assert _is_consistent_device_type (
229
- device , policy_device , passing_device , d .device .type
229
+ device , policy_device , storing_device , d .device .type
230
230
)
231
231
break
232
232
@@ -833,7 +833,7 @@ def create_env():
833
833
[create_env ] * 3 ,
834
834
policy = policy ,
835
835
devices = [torch .device ("cuda:0" )] * 3 ,
836
- passing_devices = [torch .device ("cuda:0" )] * 3 ,
836
+ storing_devices = [torch .device ("cuda:0" )] * 3 ,
837
837
)
838
838
# collect state_dict
839
839
state_dict = collector .state_dict ()
@@ -1010,13 +1010,13 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
1010
1010
1011
1011
1012
1012
@pytest .mark .parametrize ("device" , ["cuda" , "cpu" ])
1013
- @pytest .mark .parametrize ("passing_device " , ["cuda" , "cpu" ])
1013
+ @pytest .mark .parametrize ("storing_device " , ["cuda" , "cpu" ])
1014
1014
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "no cuda device found" )
1015
- def test_collector_device_combinations (device , passing_device ):
1015
+ def test_collector_device_combinations (device , storing_device ):
1016
1016
if (
1017
1017
_os_is_windows
1018
1018
and _python_is_3_10
1019
- and passing_device == "cuda"
1019
+ and storing_device == "cuda"
1020
1020
and device == "cuda"
1021
1021
):
1022
1022
pytest .skip ("Windows fatal exception: access violation in torch.storage" )
@@ -1036,11 +1036,11 @@ def env_fn(seed):
1036
1036
max_frames_per_traj = 2000 ,
1037
1037
total_frames = 20000 ,
1038
1038
device = device ,
1039
- passing_device = passing_device ,
1039
+ storing_device = storing_device ,
1040
1040
pin_memory = False ,
1041
1041
)
1042
1042
batch = next (collector .iterator ())
1043
- assert batch .device == torch .device (passing_device )
1043
+ assert batch .device == torch .device (storing_device )
1044
1044
collector .shutdown ()
1045
1045
1046
1046
collector = MultiSyncDataCollector (
@@ -1057,13 +1057,13 @@ def env_fn(seed):
1057
1057
devices = [
1058
1058
device ,
1059
1059
],
1060
- passing_devices = [
1061
- passing_device ,
1060
+ storing_devices = [
1061
+ storing_device ,
1062
1062
],
1063
1063
pin_memory = False ,
1064
1064
)
1065
1065
batch = next (collector .iterator ())
1066
- assert batch .device == torch .device (passing_device )
1066
+ assert batch .device == torch .device (storing_device )
1067
1067
collector .shutdown ()
1068
1068
1069
1069
collector = MultiaSyncDataCollector (
@@ -1080,13 +1080,13 @@ def env_fn(seed):
1080
1080
devices = [
1081
1081
device ,
1082
1082
],
1083
- passing_devices = [
1084
- passing_device ,
1083
+ storing_devices = [
1084
+ storing_device ,
1085
1085
],
1086
1086
pin_memory = False ,
1087
1087
)
1088
1088
batch = next (collector .iterator ())
1089
- assert batch .device == torch .device (passing_device )
1089
+ assert batch .device == torch .device (storing_device )
1090
1090
collector .shutdown ()
1091
1091
1092
1092
0 commit comments