Skip to content

Commit 0475cbf

Browse files
author
Vincent Moens
committed
[Performance] Memoize calls to encode and related methods within step
ghstack-source-id: 8acd483 Pull Request resolved: #2907
1 parent c5afe3c commit 0475cbf

File tree

5 files changed

+491
-14
lines changed

5 files changed

+491
-14
lines changed

test/test_specs.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,21 +69,24 @@
6969

7070
pytestmark = [
7171
pytest.mark.filterwarnings("error"),
72+
pytest.mark.filterwarnings("ignore: memoized encoding is an experimental feature"),
7273
]
7374

7475

7576
class TestRanges:
7677
@pytest.mark.parametrize(
7778
"dtype", [torch.float32, torch.float16, torch.float64, None]
7879
)
79-
def test_bounded(self, dtype):
80+
@pytest.mark.parametrize("memo", [True, False])
81+
def test_bounded(self, dtype, memo):
8082
torch.manual_seed(0)
8183
np.random.seed(0)
8284
for _ in range(100):
8385
bounds = torch.randn(2).sort()[0]
8486
ts = Bounded(
8587
bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype
8688
)
89+
ts.memoize_encode(mode=memo)
8790
_dtype = dtype
8891
if dtype is None:
8992
_dtype = torch.get_default_dtype()
@@ -93,28 +96,36 @@ def test_bounded(self, dtype):
9396
assert ts.is_in(r)
9497
assert r.dtype is _dtype
9598
ts.is_in(ts.encode(bounds.mean()))
99+
ts.erase_memoize_cache()
96100
ts.is_in(ts.encode(bounds.mean().item()))
101+
ts.erase_memoize_cache()
97102
assert (ts.encode(ts.to_numpy(r)) == r).all()
98103

99104
@pytest.mark.parametrize("cls", [OneHot, Categorical])
100-
def test_discrete(self, cls):
105+
@pytest.mark.parametrize("memo", [True, False])
106+
def test_discrete(self, cls, memo):
101107
torch.manual_seed(0)
102108
np.random.seed(0)
103109

104110
ts = cls(10)
111+
ts.memoize_encode(memo)
105112
for _ in range(100):
106113
r = ts.rand()
107114
assert (ts._project(r) == r).all()
108115
ts.to_numpy(r)
109116
ts.encode(torch.tensor([5]))
117+
ts.erase_memoize_cache()
110118
ts.encode(torch.tensor(5).numpy())
119+
ts.erase_memoize_cache()
111120
ts.encode(9)
112121
with pytest.raises(AssertionError), set_global_var(
113122
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
114123
):
124+
ts.erase_memoize_cache()
115125
ts.encode(torch.tensor([11])) # out of bounds
116126
assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE
117127
assert ts.is_in(r)
128+
ts.erase_memoize_cache()
118129
assert (ts.encode(ts.to_numpy(r)) == r).all()
119130

120131
@pytest.mark.parametrize(
@@ -139,14 +150,16 @@ def test_unbounded(self, dtype):
139150
"dtype", [torch.float32, torch.float16, torch.float64, None]
140151
)
141152
@pytest.mark.parametrize("shape", [[], torch.Size([3])])
142-
def test_ndbounded(self, dtype, shape):
153+
@pytest.mark.parametrize("memo", [True, False])
154+
def test_ndbounded(self, dtype, shape, memo):
143155
torch.manual_seed(0)
144156
np.random.seed(0)
145157

146158
for _ in range(100):
147159
lb = torch.rand(10) - 1
148160
ub = torch.rand(10) + 1
149161
ts = Bounded(lb, ub, dtype=dtype)
162+
ts.memoize_encode(memo)
150163
_dtype = dtype
151164
if dtype is None:
152165
_dtype = torch.get_default_dtype()
@@ -160,19 +173,23 @@ def test_ndbounded(self, dtype, shape):
160173
).all(), f"{r[r <= lb] - lb.expand_as(r)[r <= lb]} -- {r[r >= ub] - ub.expand_as(r)[r >= ub]} "
161174
ts.to_numpy(r)
162175
assert ts.is_in(r)
176+
ts.erase_memoize_cache()
163177
ts.encode(lb + torch.rand(10) * (ub - lb))
178+
ts.erase_memoize_cache()
164179
ts.encode((lb + torch.rand(10) * (ub - lb)).numpy())
165180

166181
if not shape:
167182
assert (ts.encode(ts.to_numpy(r)) == r).all()
168183
else:
169184
with pytest.raises(RuntimeError, match="Shape mismatch"):
185+
ts.erase_memoize_cache()
170186
ts.encode(ts.to_numpy(r))
171187
assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all()
172188

173189
with pytest.raises(AssertionError), set_global_var(
174190
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
175191
):
192+
ts.erase_memoize_cache()
176193
ts.encode(torch.rand(10) + 3) # out of bounds
177194
with pytest.raises(AssertionError), set_global_var(
178195
torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True
@@ -242,10 +259,12 @@ def test_binary(self, n, shape):
242259
],
243260
)
244261
@pytest.mark.parametrize("shape", [(), torch.Size([3])])
245-
def test_mult_onehot(self, shape, ns):
262+
@pytest.mark.parametrize("memo", [True, False])
263+
def test_mult_onehot(self, shape, ns, memo):
246264
torch.manual_seed(0)
247265
np.random.seed(0)
248266
ts = MultiOneHot(nvec=ns)
267+
ts.memoize_encode(memo)
249268
for _ in range(100):
250269
r = ts.rand(shape)
251270
assert (ts._project(r) == r).all()
@@ -260,9 +279,11 @@ def test_mult_onehot(self, shape, ns):
260279
assert not ts.is_in(categorical)
261280
# assert (ts.encode(categorical) == r).all()
262281
if not shape:
282+
ts.erase_memoize_cache()
263283
assert (ts.encode(categorical) == r).all()
264284
else:
265285
with pytest.raises(RuntimeError, match="is invalid for input of size"):
286+
ts.erase_memoize_cache()
266287
ts.encode(categorical)
267288
assert (ts.expand(*shape, *ts.shape).encode(categorical) == r).all()
268289

@@ -455,8 +476,10 @@ def test_del(self, shape, is_complete, device, dtype):
455476
assert "obs" not in ts.keys()
456477
assert "act" in ts.keys()
457478

458-
def test_encode(self, shape, is_complete, device, dtype):
479+
@pytest.mark.parametrize("memo", [True, False])
480+
def test_encode(self, shape, is_complete, device, dtype, memo):
459481
ts = self._composite_spec(shape, is_complete, device, dtype)
482+
ts.memoize_encode(memo)
460483
if dtype is None:
461484
dtype = torch.get_default_dtype()
462485

@@ -465,6 +488,7 @@ def test_encode(self, shape, is_complete, device, dtype):
465488
raw_vals = {"obs": r["obs"].cpu().numpy()}
466489
if is_complete:
467490
raw_vals["act"] = r["act"].cpu().numpy()
491+
ts.erase_memoize_cache()
468492
encoded_vals = ts.encode(raw_vals)
469493

470494
assert encoded_vals["obs"].dtype == dtype

torchrl/collectors/collectors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import numpy as np
2828
import torch
2929
import torch.nn as nn
30+
3031
from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
3132
from tensordict.base import NO_DEFAULT
3233
from tensordict.nn import CudaGraphModule, TensorDictModule

0 commit comments

Comments
 (0)