@@ -189,7 +189,9 @@ def test_additivegaussian_wrapper(
189
189
default_interaction_mode = "random" ,
190
190
).to (device )
191
191
given_spec = action_spec if spec_origin == "spec" else None
192
- exploratory_policy = AdditiveGaussianWrapper (policy , spec = given_spec ).to (device )
192
+ exploratory_policy = AdditiveGaussianWrapper (
193
+ policy , spec = given_spec , safe = False
194
+ ).to (device )
193
195
194
196
tensordict = TensorDict (
195
197
batch_size = [batch ],
@@ -198,21 +200,20 @@ def test_additivegaussian_wrapper(
198
200
)
199
201
out_noexp = []
200
202
out = []
201
- if exploratory_policy .spec is not None :
202
- for _ in range (n_steps ):
203
- tensordict_noexp = policy (tensordict .select ("observation" ))
204
- tensordict = exploratory_policy (tensordict )
205
- out .append (tensordict .clone ())
206
- out_noexp .append (tensordict_noexp .clone ())
207
- tensordict .set_ ("observation" , torch .randn (batch , d_obs , device = device ))
208
- out = torch .stack (out , 0 )
209
- out_noexp = torch .stack (out_noexp , 0 )
210
- assert (out_noexp .get ("action" ) != out .get ("action" )).all ()
211
- if spec_origin is not None :
212
- assert (out .get ("action" ) <= 1.0 ).all (), out .get ("action" ).min ()
213
- assert (out .get ("action" ) >= - 1.0 ).all (), out .get ("action" ).max ()
214
- if action_spec is not None :
215
- assert action_spec .is_in (out .get ("action" ))
203
+ for _ in range (n_steps ):
204
+ tensordict_noexp = policy (tensordict .select ("observation" ))
205
+ tensordict = exploratory_policy (tensordict )
206
+ out .append (tensordict .clone ())
207
+ out_noexp .append (tensordict_noexp .clone ())
208
+ tensordict .set_ ("observation" , torch .randn (batch , d_obs , device = device ))
209
+ out = torch .stack (out , 0 )
210
+ out_noexp = torch .stack (out_noexp , 0 )
211
+ assert (out_noexp .get ("action" ) != out .get ("action" )).all ()
212
+ if spec_origin is not None :
213
+ assert (out .get ("action" ) <= 1.0 ).all (), out .get ("action" ).min ()
214
+ assert (out .get ("action" ) >= - 1.0 ).all (), out .get ("action" ).max ()
215
+ if action_spec is not None :
216
+ assert action_spec .is_in (out .get ("action" ))
216
217
217
218
218
219
@pytest .mark .parametrize ("state_dim" , [7 ])
0 commit comments