69
69
70
70
pytestmark = [
71
71
pytest .mark .filterwarnings ("error" ),
72
+ pytest .mark .filterwarnings ("ignore: memoized encoding is an experimental feature" ),
72
73
]
73
74
74
75
75
76
class TestRanges :
76
77
@pytest .mark .parametrize (
77
78
"dtype" , [torch .float32 , torch .float16 , torch .float64 , None ]
78
79
)
79
- def test_bounded (self , dtype ):
80
+ @pytest .mark .parametrize ("memo" , [True , False ])
81
+ def test_bounded (self , dtype , memo ):
80
82
torch .manual_seed (0 )
81
83
np .random .seed (0 )
82
84
for _ in range (100 ):
83
85
bounds = torch .randn (2 ).sort ()[0 ]
84
86
ts = Bounded (
85
87
bounds [0 ].item (), bounds [1 ].item (), torch .Size ((1 ,)), dtype = dtype
86
88
)
89
+ ts .memoize_encode (mode = memo )
87
90
_dtype = dtype
88
91
if dtype is None :
89
92
_dtype = torch .get_default_dtype ()
@@ -93,28 +96,36 @@ def test_bounded(self, dtype):
93
96
assert ts .is_in (r )
94
97
assert r .dtype is _dtype
95
98
ts .is_in (ts .encode (bounds .mean ()))
99
+ ts .erase_memoize_cache ()
96
100
ts .is_in (ts .encode (bounds .mean ().item ()))
101
+ ts .erase_memoize_cache ()
97
102
assert (ts .encode (ts .to_numpy (r )) == r ).all ()
98
103
99
104
@pytest .mark .parametrize ("cls" , [OneHot , Categorical ])
100
- def test_discrete (self , cls ):
105
+ @pytest .mark .parametrize ("memo" , [True , False ])
106
+ def test_discrete (self , cls , memo ):
101
107
torch .manual_seed (0 )
102
108
np .random .seed (0 )
103
109
104
110
ts = cls (10 )
111
+ ts .memoize_encode (memo )
105
112
for _ in range (100 ):
106
113
r = ts .rand ()
107
114
assert (ts ._project (r ) == r ).all ()
108
115
ts .to_numpy (r )
109
116
ts .encode (torch .tensor ([5 ]))
117
+ ts .erase_memoize_cache ()
110
118
ts .encode (torch .tensor (5 ).numpy ())
119
+ ts .erase_memoize_cache ()
111
120
ts .encode (9 )
112
121
with pytest .raises (AssertionError ), set_global_var (
113
122
torchrl .data .tensor_specs , "_CHECK_SPEC_ENCODE" , True
114
123
):
124
+ ts .erase_memoize_cache ()
115
125
ts .encode (torch .tensor ([11 ])) # out of bounds
116
126
assert not torchrl .data .tensor_specs ._CHECK_SPEC_ENCODE
117
127
assert ts .is_in (r )
128
+ ts .erase_memoize_cache ()
118
129
assert (ts .encode (ts .to_numpy (r )) == r ).all ()
119
130
120
131
@pytest .mark .parametrize (
@@ -139,14 +150,16 @@ def test_unbounded(self, dtype):
139
150
"dtype" , [torch .float32 , torch .float16 , torch .float64 , None ]
140
151
)
141
152
@pytest .mark .parametrize ("shape" , [[], torch .Size ([3 ])])
142
- def test_ndbounded (self , dtype , shape ):
153
+ @pytest .mark .parametrize ("memo" , [True , False ])
154
+ def test_ndbounded (self , dtype , shape , memo ):
143
155
torch .manual_seed (0 )
144
156
np .random .seed (0 )
145
157
146
158
for _ in range (100 ):
147
159
lb = torch .rand (10 ) - 1
148
160
ub = torch .rand (10 ) + 1
149
161
ts = Bounded (lb , ub , dtype = dtype )
162
+ ts .memoize_encode (memo )
150
163
_dtype = dtype
151
164
if dtype is None :
152
165
_dtype = torch .get_default_dtype ()
@@ -160,19 +173,23 @@ def test_ndbounded(self, dtype, shape):
160
173
).all (), f"{ r [r <= lb ] - lb .expand_as (r )[r <= lb ]} -- { r [r >= ub ] - ub .expand_as (r )[r >= ub ]} "
161
174
ts .to_numpy (r )
162
175
assert ts .is_in (r )
176
+ ts .erase_memoize_cache ()
163
177
ts .encode (lb + torch .rand (10 ) * (ub - lb ))
178
+ ts .erase_memoize_cache ()
164
179
ts .encode ((lb + torch .rand (10 ) * (ub - lb )).numpy ())
165
180
166
181
if not shape :
167
182
assert (ts .encode (ts .to_numpy (r )) == r ).all ()
168
183
else :
169
184
with pytest .raises (RuntimeError , match = "Shape mismatch" ):
185
+ ts .erase_memoize_cache ()
170
186
ts .encode (ts .to_numpy (r ))
171
187
assert (ts .expand (* shape , * ts .shape ).encode (ts .to_numpy (r )) == r ).all ()
172
188
173
189
with pytest .raises (AssertionError ), set_global_var (
174
190
torchrl .data .tensor_specs , "_CHECK_SPEC_ENCODE" , True
175
191
):
192
+ ts .erase_memoize_cache ()
176
193
ts .encode (torch .rand (10 ) + 3 ) # out of bounds
177
194
with pytest .raises (AssertionError ), set_global_var (
178
195
torchrl .data .tensor_specs , "_CHECK_SPEC_ENCODE" , True
@@ -242,10 +259,12 @@ def test_binary(self, n, shape):
242
259
],
243
260
)
244
261
@pytest .mark .parametrize ("shape" , [(), torch .Size ([3 ])])
245
- def test_mult_onehot (self , shape , ns ):
262
+ @pytest .mark .parametrize ("memo" , [True , False ])
263
+ def test_mult_onehot (self , shape , ns , memo ):
246
264
torch .manual_seed (0 )
247
265
np .random .seed (0 )
248
266
ts = MultiOneHot (nvec = ns )
267
+ ts .memoize_encode (memo )
249
268
for _ in range (100 ):
250
269
r = ts .rand (shape )
251
270
assert (ts ._project (r ) == r ).all ()
@@ -260,9 +279,11 @@ def test_mult_onehot(self, shape, ns):
260
279
assert not ts .is_in (categorical )
261
280
# assert (ts.encode(categorical) == r).all()
262
281
if not shape :
282
+ ts .erase_memoize_cache ()
263
283
assert (ts .encode (categorical ) == r ).all ()
264
284
else :
265
285
with pytest .raises (RuntimeError , match = "is invalid for input of size" ):
286
+ ts .erase_memoize_cache ()
266
287
ts .encode (categorical )
267
288
assert (ts .expand (* shape , * ts .shape ).encode (categorical ) == r ).all ()
268
289
@@ -455,8 +476,10 @@ def test_del(self, shape, is_complete, device, dtype):
455
476
assert "obs" not in ts .keys ()
456
477
assert "act" in ts .keys ()
457
478
458
- def test_encode (self , shape , is_complete , device , dtype ):
479
+ @pytest .mark .parametrize ("memo" , [True , False ])
480
+ def test_encode (self , shape , is_complete , device , dtype , memo ):
459
481
ts = self ._composite_spec (shape , is_complete , device , dtype )
482
+ ts .memoize_encode (memo )
460
483
if dtype is None :
461
484
dtype = torch .get_default_dtype ()
462
485
@@ -465,6 +488,7 @@ def test_encode(self, shape, is_complete, device, dtype):
465
488
raw_vals = {"obs" : r ["obs" ].cpu ().numpy ()}
466
489
if is_complete :
467
490
raw_vals ["act" ] = r ["act" ].cpu ().numpy ()
491
+ ts .erase_memoize_cache ()
468
492
encoded_vals = ts .encode (raw_vals )
469
493
470
494
assert encoded_vals ["obs" ].dtype == dtype
0 commit comments