@@ -198,20 +198,21 @@ def test_additivegaussian_wrapper(
198
198
)
199
199
out_noexp = []
200
200
out = []
201
- for _ in range (n_steps ):
202
- tensordict_noexp = policy (tensordict .select ("observation" ))
203
- tensordict = exploratory_policy (tensordict )
204
- out .append (tensordict .clone ())
205
- out_noexp .append (tensordict_noexp .clone ())
206
- tensordict .set_ ("observation" , torch .randn (batch , d_obs , device = device ))
207
- out = torch .stack (out , 0 )
208
- out_noexp = torch .stack (out_noexp , 0 )
209
- assert (out_noexp .get ("action" ) != out .get ("action" )).all ()
210
- if spec_origin is not None :
211
- assert (out .get ("action" ) <= 1.0 ).all (), out .get ("action" ).min ()
212
- assert (out .get ("action" ) >= - 1.0 ).all (), out .get ("action" ).max ()
213
- if action_spec is not None :
214
- assert action_spec .is_in (out .get ("action" ))
201
+ if exploratory_policy .spec is not None :
202
+ for _ in range (n_steps ):
203
+ tensordict_noexp = policy (tensordict .select ("observation" ))
204
+ tensordict = exploratory_policy (tensordict )
205
+ out .append (tensordict .clone ())
206
+ out_noexp .append (tensordict_noexp .clone ())
207
+ tensordict .set_ ("observation" , torch .randn (batch , d_obs , device = device ))
208
+ out = torch .stack (out , 0 )
209
+ out_noexp = torch .stack (out_noexp , 0 )
210
+ assert (out_noexp .get ("action" ) != out .get ("action" )).all ()
211
+ if spec_origin is not None :
212
+ assert (out .get ("action" ) <= 1.0 ).all (), out .get ("action" ).min ()
213
+ assert (out .get ("action" ) >= - 1.0 ).all (), out .get ("action" ).max ()
214
+ if action_spec is not None :
215
+ assert action_spec .is_in (out .get ("action" ))
215
216
216
217
217
218
@pytest .mark .parametrize ("state_dim" , [7 ])
0 commit comments