16
16
from functools import partial
17
17
from sys import platform
18
18
19
+ import numpy as np
20
+
19
21
import pytest
20
22
21
23
import tensordict .tensordict
@@ -2288,7 +2290,7 @@ class TestHash(TransformBase):
2288
2290
def test_transform_no_env (self , datatype ):
2289
2291
if datatype == "tensor" :
2290
2292
obs = torch .tensor (10 )
2291
- hash_fn = hash
2293
+ hash_fn = lambda x : torch . tensor ( hash ( x ))
2292
2294
elif datatype == "str" :
2293
2295
obs = "abcdefg"
2294
2296
hash_fn = Hash .reproducible_hash
@@ -2302,6 +2304,7 @@ def test_transform_no_env(self, datatype):
2302
2304
)
2303
2305
2304
2306
def fn0(x):
    # Hash every element of ``x`` separately, then stack the per-element
    # hashes into a single tensor.
    element_hashes = [Hash.reproducible_hash(element) for element in x]
    return torch.stack(element_hashes)
2306
2309
2307
2310
hash_fn = fn0
@@ -2334,7 +2337,7 @@ def test_single_trans_env_check(self, datatype):
2334
2337
t = Hash (
2335
2338
in_keys = ["observation" ],
2336
2339
out_keys = ["hashing" ],
2337
- hash_fn = hash ,
2340
+ hash_fn = lambda x : torch . tensor ( hash ( x )) ,
2338
2341
)
2339
2342
base_env = CountingEnv ()
2340
2343
elif datatype == "str" :
@@ -2353,7 +2356,7 @@ def make_env():
2353
2356
t = Hash (
2354
2357
in_keys = ["observation" ],
2355
2358
out_keys = ["hashing" ],
2356
- hash_fn = hash ,
2359
+ hash_fn = lambda x : torch . tensor ( hash ( x )) ,
2357
2360
)
2358
2361
base_env = CountingEnv ()
2359
2362
@@ -2376,7 +2379,7 @@ def make_env():
2376
2379
t = Hash (
2377
2380
in_keys = ["observation" ],
2378
2381
out_keys = ["hashing" ],
2379
- hash_fn = hash ,
2382
+ hash_fn = lambda x : torch . tensor ( hash ( x )) ,
2380
2383
)
2381
2384
base_env = CountingEnv ()
2382
2385
elif datatype == "str" :
@@ -2402,7 +2405,7 @@ def test_trans_serial_env_check(self, datatype):
2402
2405
t = Hash (
2403
2406
in_keys = ["observation" ],
2404
2407
out_keys = ["hashing" ],
2405
- hash_fn = lambda x : [hash (x [0 ]), hash (x [1 ])],
2408
+ hash_fn = lambda x : torch . tensor ( [hash (x [0 ]), hash (x [1 ])]) ,
2406
2409
)
2407
2410
base_env = CountingEnv
2408
2411
elif datatype == "str" :
@@ -2422,7 +2425,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
2422
2425
t = Hash (
2423
2426
in_keys = ["observation" ],
2424
2427
out_keys = ["hashing" ],
2425
- hash_fn = lambda x : [hash (x [0 ]), hash (x [1 ])],
2428
+ hash_fn = lambda x : torch . tensor ( [hash (x [0 ]), hash (x [1 ])]) ,
2426
2429
)
2427
2430
base_env = CountingEnv
2428
2431
elif datatype == "str" :
@@ -2457,7 +2460,7 @@ def test_transform_compose(self, datatype):
2457
2460
t = Hash (
2458
2461
in_keys = ["observation" ],
2459
2462
out_keys = ["hashing" ],
2460
- hash_fn = hash ,
2463
+ hash_fn = lambda x : torch . tensor ( hash ( x )) ,
2461
2464
)
2462
2465
t = Compose (t )
2463
2466
td_hashed = t (td )
@@ -2469,7 +2472,7 @@ def test_transform_model(self):
2469
2472
t = Hash (
2470
2473
in_keys = [("next" , "observation" ), ("observation" ,)],
2471
2474
out_keys = [("next" , "hashing" ), ("hashing" ,)],
2472
- hash_fn = hash ,
2475
+ hash_fn = lambda x : torch . tensor ( hash ( x )) ,
2473
2476
)
2474
2477
model = nn .Sequential (t , nn .Identity ())
2475
2478
td = TensorDict (
@@ -2486,7 +2489,7 @@ def test_transform_env(self):
2486
2489
t = Hash (
2487
2490
in_keys = ["observation" ],
2488
2491
out_keys = ["hashing" ],
2489
- hash_fn = hash ,
2492
+ hash_fn = lambda x : torch . tensor ( hash ( x )) ,
2490
2493
)
2491
2494
env = TransformedEnv (GymEnv (PENDULUM_VERSIONED ()), t )
2492
2495
assert env .observation_spec ["hashing" ]
@@ -2499,7 +2502,7 @@ def test_transform_rb(self, rbclass):
2499
2502
t = Hash (
2500
2503
in_keys = [("next" , "observation" ), ("observation" ,)],
2501
2504
out_keys = [("next" , "hashing" ), ("hashing" ,)],
2502
- hash_fn = lambda x : [hash (x [0 ]), hash (x [1 ])],
2505
+ hash_fn = lambda x : torch . tensor ( [hash (x [0 ]), hash (x [1 ])]) ,
2503
2506
)
2504
2507
rb = rbclass (storage = LazyTensorStorage (10 ))
2505
2508
rb .append_transform (t )
@@ -2519,18 +2522,73 @@ def test_transform_rb(self, rbclass):
2519
2522
assert "observation" in td .keys ()
2520
2523
assert ("next" , "observation" ) in td .keys (True )
2521
2524
2522
- def test_transform_inverse (self ):
2523
- return
2524
- env = CountingEnv ()
2525
- with pytest .raises (TypeError ):
2526
- env = env .append_transform (
2527
- Hash (
2528
- in_keys = [],
2529
- out_keys = [],
2530
- in_keys_inv = ["action" ],
2531
- out_keys_inv = ["action_hash" ],
2532
- )
2533
- )
2525
@pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}])
def test_transform_inverse(self, repertoire_gen):
    """Round-trip observations through ``Hash`` and back through ``inv``.

    Four observations (two strings, two random tensors) are hashed, the
    hashed outputs are fed to ``inv``, and each recovered observation is
    compared with the one that was originally hashed.

    Args:
        repertoire_gen: zero-argument callable producing the ``repertoire``
            argument handed to ``Hash`` (``None`` or a fresh empty dict).
    """
    hash_transform = Hash(
        in_keys=["observation"],
        out_keys=["hashing"],
        in_keys_inv=["observation"],
        out_keys_inv=["hashing"],
        repertoire=repertoire_gen(),
    )
    samples = [
        TensorDict({"observation": "test string"}),
        TensorDict({"observation": torch.randn(10)}),
        TensorDict({"observation": "another string"}),
        TensorDict({"observation": torch.randn(3, 2, 1, 8)}),
    ]

    # Keep only the hashed entry so the inverse pass cannot simply read the
    # observation back from its own input.
    hashed_samples = []
    for sample in samples:
        hashed_samples.append(hash_transform(sample.clone()).exclude("observation"))

    # Hash the same samples a second time: repeated inputs must not
    # overwrite or duplicate entries in the repertoire.
    for sample in samples:
        hash_transform(sample.clone())

    assert len(hash_transform._repertoire) == 4

    recovered_samples = []
    for hashed in hashed_samples:
        recovered_samples.append(hash_transform.inv(hashed.clone()))

    for sample, recovered in zip(samples, recovered_samples):
        expected = sample["observation"]
        actual = recovered["observation"]
        if not torch.is_tensor(expected):
            assert expected == actual
            continue
        # Tensor observations are compared element-wise.
        assert (expected == actual).all()
2557
+
2558
@pytest.mark.parametrize("repertoire_gen", [lambda: None, lambda: {}])
def test_repertoire(self, repertoire_gen):
    """Check what ``Hash`` records in its repertoire for varied inputs.

    With a dict repertoire, the entry stored under each hash's repertoire
    key must be the very object returned by ``get_input_from_hash`` and must
    equal the value that was hashed. With ``repertoire=None``,
    ``get_input_from_hash`` is expected to raise ``RuntimeError``.

    Args:
        repertoire_gen: zero-argument callable producing the ``repertoire``
            argument handed to ``Hash`` (``None`` or a fresh empty dict).
    """
    repertoire = repertoire_gen()
    hasher = Hash(in_keys=["observation"], out_keys=["hashing"], repertoire=repertoire)
    values = [
        "string",
        ["a", "b"],
        torch.randn(3, 4, 1),
        torch.randn(()),
        torch.randn(0),
        1234,
        [1, 2, 3, 4],
    ]

    # Hash every value through a one-entry TensorDict and keep the hash.
    hashes = [
        hasher(TensorDict({"observation": value}).clone()).clone()["hashing"]
        for value in values
    ]

    for digest, value in zip(hashes, values):
        if repertoire is None:
            # No repertoire was supplied, so the lookup must fail.
            with pytest.raises(RuntimeError):
                hasher.get_input_from_hash(digest)
            continue

        stored = repertoire[hasher.hash_to_repertoire_key(digest)]
        assert stored is hasher.get_input_from_hash(digest)

        # Compare according to the type of the stored object.
        if torch.is_tensor(stored):
            matches = (stored == torch.as_tensor(value)).all()
        elif isinstance(stored, np.ndarray):
            matches = (stored == np.asarray(value)).all()
        else:
            matches = stored == value
        assert matches
2534
2592
2535
2593
2536
2594
@pytest .mark .skipif (
0 commit comments