Skip to content

Commit 5c4b8bd

Browse files
committed
Automated tutorials push
1 parent 2c88aec commit 5c4b8bd

File tree

216 files changed

+17987
-15715
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

216 files changed

+17987
-15715
lines changed

_downloads/0bd6b9a8e47e1d64e4d20ef356a6095d/onnx_registry_tutorial.ipynb

Lines changed: 214 additions & 248 deletions
Large diffs are not rendered by default.
Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {
7+
"collapsed": false
8+
},
9+
"outputs": [],
10+
"source": [
11+
"# For tips on running notebooks in Google Colab, see\n",
12+
"# https://pytorch.org/tutorials/beginner/colab\n",
13+
"%matplotlib inline"
14+
]
15+
},
16+
{
17+
"cell_type": "markdown",
18+
"metadata": {},
19+
"source": [
20+
"[Introduction to ONNX](intro_onnx.html) \\|\\| [Exporting a PyTorch model\n",
21+
"to ONNX](export_simple_model_to_onnx_tutorial.html) \\|\\| [Extending the\n",
22+
"ONNX exporter operator support](onnx_registry_tutorial.html) \\|\\|\n",
23+
"**\\`Export a model with control flow to ONNX**\n",
24+
"\n",
25+
"Export a model with control flow to ONNX\n",
26+
"========================================\n",
27+
"\n",
28+
"**Author**: [Xavier Dupré](https://github.com/xadupre)\n"
29+
]
30+
},
31+
{
32+
"cell_type": "markdown",
33+
"metadata": {},
34+
"source": [
35+
"Overview\n",
36+
"========\n",
37+
"\n",
38+
"This tutorial demonstrates how to handle control flow logic while\n",
39+
"exporting a PyTorch model to ONNX. It highlights the challenges of\n",
40+
"exporting conditional statements directly and provides solutions to\n",
41+
"circumvent them.\n",
42+
"\n",
43+
"Conditional logic cannot be exported into ONNX unless they refactored to\n",
44+
"use `torch.cond`{.interpreted-text role=\"func\"}. Let\\'s start with a\n",
45+
"simple model implementing a test.\n",
46+
"\n",
47+
"What you will learn:\n",
48+
"\n",
49+
"- How to refactor the model to use `torch.cond`{.interpreted-text\n",
50+
" role=\"func\"} for exporting.\n",
51+
"- How to export a model with control flow logic to ONNX.\n",
52+
"- How to optimize the exported model using the ONNX optimizer.\n",
53+
"\n",
54+
"Prerequisites\n",
55+
"-------------\n",
56+
"\n",
57+
"- `torch >= 2.6`\n"
58+
]
59+
},
60+
{
61+
"cell_type": "code",
62+
"execution_count": null,
63+
"metadata": {
64+
"collapsed": false
65+
},
66+
"outputs": [],
67+
"source": [
68+
"import torch"
69+
]
70+
},
71+
{
72+
"cell_type": "markdown",
73+
"metadata": {},
74+
"source": [
75+
"Define the Models\n",
76+
"=================\n",
77+
"\n",
78+
"Two models are defined:\n",
79+
"\n",
80+
"`ForwardWithControlFlowTest`: A model with a forward method containing\n",
81+
"an if-else conditional.\n",
82+
"\n",
83+
"`ModelWithControlFlowTest`: A model that incorporates\n",
84+
"`ForwardWithControlFlowTest` as part of a simple MLP. The models are\n",
85+
"tested with a random input tensor to confirm they execute as expected.\n"
86+
]
87+
},
88+
{
89+
"cell_type": "code",
90+
"execution_count": null,
91+
"metadata": {
92+
"collapsed": false
93+
},
94+
"outputs": [],
95+
"source": [
96+
"class ForwardWithControlFlowTest(torch.nn.Module):\n",
97+
" def forward(self, x):\n",
98+
" if x.sum():\n",
99+
" return x * 2\n",
100+
" return -x\n",
101+
"\n",
102+
"\n",
103+
"class ModelWithControlFlowTest(torch.nn.Module):\n",
104+
" def __init__(self):\n",
105+
" super().__init__()\n",
106+
" self.mlp = torch.nn.Sequential(\n",
107+
" torch.nn.Linear(3, 2),\n",
108+
" torch.nn.Linear(2, 1),\n",
109+
" ForwardWithControlFlowTest(),\n",
110+
" )\n",
111+
"\n",
112+
" def forward(self, x):\n",
113+
" out = self.mlp(x)\n",
114+
" return out\n",
115+
"\n",
116+
"\n",
117+
"model = ModelWithControlFlowTest()"
118+
]
119+
},
120+
{
121+
"cell_type": "markdown",
122+
"metadata": {},
123+
"source": [
124+
"Exporting the Model: First Attempt\n",
125+
"==================================\n",
126+
"\n",
127+
"Exporting this model using torch.export.export fails because the control\n",
128+
"flow logic in the forward pass creates a graph break that the exporter\n",
129+
"cannot handle. This behavior is expected, as conditional logic not\n",
130+
"written using `torch.cond`{.interpreted-text role=\"func\"} is\n",
131+
"unsupported.\n",
132+
"\n",
133+
"A try-except block is used to capture the expected failure during the\n",
134+
"export process. If the export unexpectedly succeeds, an `AssertionError`\n",
135+
"is raised.\n"
136+
]
137+
},
138+
{
139+
"cell_type": "code",
140+
"execution_count": null,
141+
"metadata": {
142+
"collapsed": false
143+
},
144+
"outputs": [],
145+
"source": [
146+
"x = torch.randn(3)\n",
147+
"model(x)\n",
148+
"\n",
149+
"try:\n",
150+
" torch.export.export(model, (x,), strict=False)\n",
151+
" raise AssertionError(\"This export should failed unless PyTorch now supports this model.\")\n",
152+
"except Exception as e:\n",
153+
" print(e)"
154+
]
155+
},
156+
{
157+
"cell_type": "markdown",
158+
"metadata": {},
159+
"source": [
160+
"Using `torch.onnx.export`{.interpreted-text role=\"func\"} with JIT\n",
161+
"Tracing\n",
162+
"\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\--\n",
163+
"\n",
164+
"When exporting the model using `torch.onnx.export`{.interpreted-text\n",
165+
"role=\"func\"} with the dynamo=True argument, the exporter defaults to\n",
166+
"using JIT tracing. This fallback allows the model to export, but the\n",
167+
"resulting ONNX graph may not faithfully represent the original model\n",
168+
"logic due to the limitations of tracing.\n"
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": null,
174+
"metadata": {
175+
"collapsed": false
176+
},
177+
"outputs": [],
178+
"source": [
179+
"onnx_program = torch.onnx.export(model, (x,), dynamo=True) \n",
180+
"print(onnx_program.model)"
181+
]
182+
},
183+
{
184+
"cell_type": "markdown",
185+
"metadata": {},
186+
"source": [
187+
"Suggested Patch: Refactoring with `torch.cond`{.interpreted-text\n",
188+
"role=\"func\"}\n",
189+
"\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\-\\--\n",
190+
"\n",
191+
"To make the control flow exportable, the tutorial demonstrates replacing\n",
192+
"the forward method in `ForwardWithControlFlowTest` with a refactored\n",
193+
"version that uses `torch.cond`{.interpreted-text role=\"func\"}\\`.\n",
194+
"\n",
195+
"Details of the Refactoring:\n",
196+
"\n",
197+
"Two helper functions (identity2 and neg) represent the branches of the\n",
198+
"conditional logic: \\* `torch.cond`{.interpreted-text role=\"func\"}[ is\n",
199+
"used to specify the condition and the two branches along with the input\n",
200+
"arguments. \\* The updated forward method is then dynamically assigned to\n",
201+
"the ]{.title-ref}[ForwardWithControlFlowTest]{.title-ref}\\` instance\n",
202+
"within the model. A list of submodules is printed to confirm the\n",
203+
"replacement.\n"
204+
]
205+
},
206+
{
207+
"cell_type": "code",
208+
"execution_count": null,
209+
"metadata": {
210+
"collapsed": false
211+
},
212+
"outputs": [],
213+
"source": [
214+
"def new_forward(x):\n",
215+
" def identity2(x):\n",
216+
" return x * 2\n",
217+
"\n",
218+
" def neg(x):\n",
219+
" return -x\n",
220+
"\n",
221+
" return torch.cond(x.sum() > 0, identity2, neg, (x,))\n",
222+
"\n",
223+
"\n",
224+
"print(\"the list of submodules\")\n",
225+
"for name, mod in model.named_modules():\n",
226+
" print(name, type(mod))\n",
227+
" if isinstance(mod, ForwardWithControlFlowTest):\n",
228+
" mod.forward = new_forward"
229+
]
230+
},
231+
{
232+
"cell_type": "markdown",
233+
"metadata": {},
234+
"source": [
235+
"Let\\'s see what the FX graph looks like.\n"
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": null,
241+
"metadata": {
242+
"collapsed": false
243+
},
244+
"outputs": [],
245+
"source": [
246+
"print(torch.export.export(model, (x,), strict=False))"
247+
]
248+
},
249+
{
250+
"cell_type": "markdown",
251+
"metadata": {},
252+
"source": [
253+
"Let\\'s export again.\n"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": null,
259+
"metadata": {
260+
"collapsed": false
261+
},
262+
"outputs": [],
263+
"source": [
264+
"onnx_program = torch.onnx.export(model, (x,), dynamo=True) \n",
265+
"print(onnx_program.model)"
266+
]
267+
},
268+
{
269+
"cell_type": "markdown",
270+
"metadata": {},
271+
"source": [
272+
"We can optimize the model and get rid of the model local functions\n",
273+
"created to capture the control flow branches.\n"
274+
]
275+
},
276+
{
277+
"cell_type": "code",
278+
"execution_count": null,
279+
"metadata": {
280+
"collapsed": false
281+
},
282+
"outputs": [],
283+
"source": [
284+
"onnx_program.optimize() \n",
285+
"print(onnx_program.model)"
286+
]
287+
},
288+
{
289+
"cell_type": "markdown",
290+
"metadata": {},
291+
"source": [
292+
"Conclusion\n",
293+
"==========\n",
294+
"\n",
295+
"This tutorial demonstrates the challenges of exporting models with\n",
296+
"conditional logic to ONNX and presents a practical solution using\n",
297+
"`torch.cond`{.interpreted-text role=\"func\"}. While the default exporters\n",
298+
"may fail or produce imperfect graphs, refactoring the model\\'s logic\n",
299+
"ensures compatibility and generates a faithful ONNX representation.\n",
300+
"\n",
301+
"By understanding these techniques, we can overcome common pitfalls when\n",
302+
"working with control flow in PyTorch models and ensure smooth\n",
303+
"integration with ONNX workflows.\n",
304+
"\n",
305+
"Further reading\n",
306+
"===============\n",
307+
"\n",
308+
"The list below refers to tutorials that ranges from basic examples to\n",
309+
"advanced scenarios, not necessarily in the order they are listed. Feel\n",
310+
"free to jump directly to specific topics of your interest or sit tight\n",
311+
"and have fun going through all of them to learn all there is about the\n",
312+
"ONNX exporter.\n",
313+
"\n",
314+
"::: {.toctree hidden=\"\"}\n",
315+
":::\n"
316+
]
317+
}
318+
],
319+
"metadata": {
320+
"kernelspec": {
321+
"display_name": "Python 3",
322+
"language": "python",
323+
"name": "python3"
324+
},
325+
"language_info": {
326+
"codemirror_mode": {
327+
"name": "ipython",
328+
"version": 3
329+
},
330+
"file_extension": ".py",
331+
"mimetype": "text/x-python",
332+
"name": "python",
333+
"nbconvert_exporter": "python",
334+
"pygments_lexer": "ipython3",
335+
"version": "3.10.12"
336+
}
337+
},
338+
"nbformat": 4,
339+
"nbformat_minor": 0
340+
}

0 commit comments

Comments
 (0)