Skip to content

Commit 5d72561

Browse files
author
Vincent Moens
committed
[Refactor] TransformersWrapper class
ghstack-source-id: 8d54426
Pull Request resolved: #2871
1 parent 1ee0a83 commit 5d72561

File tree

8 files changed: +746 additions, -39 deletions

pytree.ipynb

Lines changed: 328 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,328 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "5f53cf70-25e6-4802-a5fe-cdaacf6deff7",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"from torch.utils._cxx_pytree import tree_map, tree_leaves, tree_flatten"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 2,
16+
"id": "2111a53d-0714-42bf-9051-4eaee5d8a86c",
17+
"metadata": {},
18+
"outputs": [],
19+
"source": [
20+
"from tensordict import TensorDict, lazy_stack, is_tensor_collection\n",
21+
"import torch\n",
22+
"from tensordict._pytree import *"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 3,
28+
"id": "951e96a4-4a8c-432a-80e3-c1d30d165ab7",
29+
"metadata": {},
30+
"outputs": [
31+
{
32+
"data": {
33+
"text/plain": [
34+
"99"
35+
]
36+
},
37+
"execution_count": 3,
38+
"metadata": {},
39+
"output_type": "execute_result"
40+
}
41+
],
42+
"source": [
43+
"d_ = d = {}\n",
44+
"for _ in range(100):\n",
45+
" newd = {}\n",
46+
" d_[\"a\"] = newd\n",
47+
" d_[\"t\"] = torch.zeros((1,))\n",
48+
" d_ = newd\n",
49+
"td = TensorDict(d, batch_size=(1,))\n",
50+
"td.depth"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": 4,
56+
"id": "53df5c9e-aac3-4b84-95ed-9619273da95f",
57+
"metadata": {},
58+
"outputs": [
59+
{
60+
"name": "stdout",
61+
"output_type": "stream",
62+
"text": [
63+
"581 μs ± 9.34 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
64+
]
65+
}
66+
],
67+
"source": [
68+
"%%timeit\n",
69+
"tree_map(lambda x: x+1, td)"
70+
]
71+
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": 5,
75+
"id": "7ff87f96-00c3-4134-926e-d2dba44d0b11",
76+
"metadata": {},
77+
"outputs": [
78+
{
79+
"name": "stdout",
80+
"output_type": "stream",
81+
"text": [
82+
"2.32 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
83+
]
84+
}
85+
],
86+
"source": [
87+
"%%timeit\n",
88+
"td + 1"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": 6,
94+
"id": "5e370705-afe2-44d0-bb79-088f4b4c7c75",
95+
"metadata": {},
96+
"outputs": [
97+
{
98+
"name": "stdout",
99+
"output_type": "stream",
100+
"text": [
101+
"694 μs ± 9.36 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
102+
]
103+
}
104+
],
105+
"source": [
106+
"%%timeit\n",
107+
"td.apply(lambda x: x+1)"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": 7,
113+
"id": "e3f93951-7be8-4be5-a385-735d456dc9d5",
114+
"metadata": {},
115+
"outputs": [
116+
{
117+
"data": {
118+
"text/plain": [
119+
"torch.Size([1])"
120+
]
121+
},
122+
"execution_count": 7,
123+
"metadata": {},
124+
"output_type": "execute_result"
125+
}
126+
],
127+
"source": [
128+
"tree_map(lambda x: x+1, td).batch_size"
129+
]
130+
},
131+
{
132+
"cell_type": "code",
133+
"execution_count": 8,
134+
"id": "6a76f7f9-a813-4cba-92c8-a90240050a99",
135+
"metadata": {},
136+
"outputs": [],
137+
"source": [
138+
"assert (tree_map(lambda x: x+1, td) == 1).all()"
139+
]
140+
},
141+
{
142+
"cell_type": "code",
143+
"execution_count": 9,
144+
"id": "8b6a46c1-6dbc-4aa4-868c-6891453efb32",
145+
"metadata": {},
146+
"outputs": [
147+
{
148+
"name": "stdout",
149+
"output_type": "stream",
150+
"text": [
151+
"214 μs ± 2.42 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
152+
]
153+
}
154+
],
155+
"source": [
156+
"%%timeit\n",
157+
"tree_flatten(td)"
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": 10,
163+
"id": "68fbb5c4-1d89-4620-ac27-8761b97b18e8",
164+
"metadata": {},
165+
"outputs": [
166+
{
167+
"name": "stdout",
168+
"output_type": "stream",
169+
"text": [
170+
"287 μs ± 7.88 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n"
171+
]
172+
}
173+
],
174+
"source": [
175+
"%%timeit\n",
176+
"list(td.values(True, True))"
177+
]
178+
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": 11,
182+
"id": "1fa58391-21b6-4f18-9b20-dd4165f7c877",
183+
"metadata": {},
184+
"outputs": [],
185+
"source": [
186+
"d_ = d = {}\n",
187+
"for _ in range(10):\n",
188+
" newd = {}\n",
189+
" d_[\"a\"] = newd\n",
190+
" d_[\"t\"] = torch.zeros((1,))\n",
191+
" d_ = newd\n",
192+
"tdls = TensorDict(d, batch_size=(1,))\n",
193+
"tdls.depth\n",
194+
"\n",
195+
"tdls = lazy_stack([tdls.clone() for _ in range(100)])"
196+
]
197+
},
198+
{
199+
"cell_type": "code",
200+
"execution_count": 12,
201+
"id": "d87daadc-0a5e-46c6-8382-f6599d199d36",
202+
"metadata": {},
203+
"outputs": [
204+
{
205+
"data": {
206+
"text/plain": [
207+
"torch.Size([100, 1])"
208+
]
209+
},
210+
"execution_count": 12,
211+
"metadata": {},
212+
"output_type": "execute_result"
213+
}
214+
],
215+
"source": [
216+
"tree_map(lambda x: x+1, tdls).batch_size"
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": 13,
222+
"id": "015dbd32-cf97-4b60-9d56-550f82e7e238",
223+
"metadata": {},
224+
"outputs": [
225+
{
226+
"name": "stdout",
227+
"output_type": "stream",
228+
"text": [
229+
"6.07 ms ± 75.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
230+
]
231+
}
232+
],
233+
"source": [
234+
"%%timeit\n",
235+
"tree_map(lambda x: x+1, tdls)"
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": 14,
241+
"id": "7cff3ea1-239c-4cab-9d47-2a7f31862b74",
242+
"metadata": {},
243+
"outputs": [
244+
{
245+
"name": "stdout",
246+
"output_type": "stream",
247+
"text": [
248+
"6.29 ms ± 101 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
249+
]
250+
}
251+
],
252+
"source": [
253+
"%%timeit\n",
254+
"tdls + 1"
255+
]
256+
},
257+
{
258+
"cell_type": "code",
259+
"execution_count": 15,
260+
"id": "e2a0a78c-0401-4d6c-92c7-b3ddbe74de42",
261+
"metadata": {},
262+
"outputs": [
263+
{
264+
"name": "stdout",
265+
"output_type": "stream",
266+
"text": [
267+
"7.1 ms ± 130 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
268+
]
269+
}
270+
],
271+
"source": [
272+
"%%timeit\n",
273+
"tdls.apply(lambda x: x+1)"
274+
]
275+
},
276+
{
277+
"cell_type": "code",
278+
"execution_count": null,
279+
"id": "2fc590d1-0774-4f79-90a2-98fe6b3ac2c3",
280+
"metadata": {},
281+
"outputs": [],
282+
"source": [
283+
"%%timeit\n",
284+
"tree_flatten(tdls)"
285+
]
286+
},
287+
{
288+
"cell_type": "code",
289+
"execution_count": null,
290+
"id": "579a2a89-802b-4f5b-9354-7a9cbb5530f5",
291+
"metadata": {},
292+
"outputs": [],
293+
"source": [
294+
"%%timeit\n",
295+
"list(tdls.values(True, True))"
296+
]
297+
},
298+
{
299+
"cell_type": "code",
300+
"execution_count": null,
301+
"id": "01f21ad4-57f6-4523-9b48-eb2697c9f8ec",
302+
"metadata": {},
303+
"outputs": [],
304+
"source": []
305+
}
306+
],
307+
"metadata": {
308+
"kernelspec": {
309+
"display_name": "Python 3 (ipykernel)",
310+
"language": "python",
311+
"name": "python3"
312+
},
313+
"language_info": {
314+
"codemirror_mode": {
315+
"name": "ipython",
316+
"version": 3
317+
},
318+
"file_extension": ".py",
319+
"mimetype": "text/x-python",
320+
"name": "python",
321+
"nbconvert_exporter": "python",
322+
"pygments_lexer": "ipython3",
323+
"version": "3.10.16"
324+
}
325+
},
326+
"nbformat": 4,
327+
"nbformat_minor": 5
328+
}

test/test_actors.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,11 +22,11 @@
2222
from torchrl.data.llm.dataset import _has_transformers
2323
from torchrl.envs import LLMEnv
2424
from torchrl.modules import (
25-
from_hf_transformers,
2625
MLP,
2726
SafeModule,
2827
TanhDelta,
2928
TanhNormal,
29+
TransformersWrapper,
3030
vLLMWrapper,
3131
)
3232
from torchrl.modules.tensordict_module.actors import (
@@ -961,7 +961,7 @@ def vllm_instance(self):
961961
(False, True, False, torch.randint(1024, (1, 10)), None),
962962
],
963963
)
964-
def test_from_hf_transformers(
964+
def test_TransformersWrapper(
965965
self, from_text, generate, return_log_probs, tokens, attention_mask
966966
):
967967
torch.manual_seed(0)
@@ -978,7 +978,7 @@ def test_from_hf_transformers(
978978
tokenizer.pad_token = tokenizer.eos_token
979979
tokenizer.padding_side = "left"
980980

981-
m = from_hf_transformers(
981+
m = TransformersWrapper(
982982
model,
983983
tokenizer=tokenizer,
984984
from_text=from_text,
@@ -1173,14 +1173,14 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
11731173
tokenizer.pad_token = tokenizer.eos_token
11741174
tokenizer.padding_side = "left"
11751175

1176-
m_generate = from_hf_transformers(
1176+
m_generate = TransformersWrapper(
11771177
model,
11781178
tokenizer=tokenizer,
11791179
from_text=from_text,
11801180
generate=True,
11811181
return_log_probs=True,
11821182
)
1183-
m_logprobs = from_hf_transformers(
1183+
m_logprobs = TransformersWrapper(
11841184
model, tokenizer=tokenizer, from_text=from_text, generate=False
11851185
)
11861186
self._check_lps(

torchrl/modules/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@
9393
)
9494
from .utils import get_primers_from_module
9595
from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip
96-
from .llm import from_hf_transformers, vLLMWrapper
96+
from .llm import TransformersWrapper, vLLMWrapper
9797

9898
__all__ = [
9999
"Actor",
@@ -177,7 +177,7 @@
177177
"VmapModule",
178178
"WorldModelWrapper",
179179
"distributions_maps",
180-
"from_hf_transformers",
180+
"TransformersWrapper",
181181
"vLLMWrapper",
182182
"get_primers_from_module",
183183
"recurrent_mode",

0 commit comments

Comments
 (0)