Commit f6024bf

angelayisvekars and Svetlana Karslioglu authored

[export] Add non-strict and derived dims tutorials (#2838)

* [export] Add non-strict and derived dims tutorials
* add another derived dim example

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
1 parent: 0f40dc9

File tree: 1 file changed, +109 −7 lines


intermediate_source/torch_export_tutorial.py

Lines changed: 109 additions & 7 deletions
@@ -114,11 +114,12 @@ def forward(self, x, y):
 # ------------
 #
 # Although ``torch.export`` shares components with ``torch.compile``,
-# the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
-# support graph breaks. This is because handling graph breaks involves interpreting
-# the unsupported operation with default Python evaluation, which is incompatible
-# with the export use case. Therefore, in order to make your model code compatible
-# with ``torch.export``, you will need to modify your code to remove graph breaks.
+# the key limitation of ``torch.export``, especially when compared to
+# ``torch.compile``, is that it does not support graph breaks. This is because
+# handling graph breaks involves interpreting the unsupported operation with
+# default Python evaluation, which is incompatible with the export use case.
+# Therefore, in order to make your model code compatible with ``torch.export``,
+# you will need to modify your code to remove graph breaks.
 #
 # A graph break is necessary in cases such as:
 #
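To make the cost of a graph break concrete, here is a minimal hedged sketch that is not part of this commit (hypothetical module name; assumes the tutorial's imports, and exact failure behavior varies across PyTorch versions): ``torch.compile`` recovers from a graph break by falling back to eager Python, while ``torch.export`` has no such fallback and must raise.

import traceback as tb
import torch
from torch.export import export

class PrintSideEffect(torch.nn.Module):
    def forward(self, x):
        print("side effect")  # a Python side effect that interrupts tracing
        return x + 1

# torch.compile splits the graph around the print and runs it eagerly
print(torch.compile(PrintSideEffect())(torch.randn(3)))

try:
    export(PrintSideEffect(), (torch.randn(3),))  # no fallback: raises
except Exception:
    tb.print_exc()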
@@ -180,8 +181,68 @@ def forward(self, x):
     tb.print_exc()
 
 ######################################################################
-# The sections below demonstrate some ways you can modify your code
-# in order to remove graph breaks.
+# Non-Strict Export
+# -----------------
+#
+# To trace the program, ``torch.export`` uses TorchDynamo, a bytecode analysis
+# engine, to symbolically analyze the Python code and build a graph based on
+# the results. This analysis allows ``torch.export`` to provide stronger
+# guarantees about safety, but not all Python code is supported, causing
+# these graph breaks.
+#
+# To address this issue, in PyTorch 2.3 we introduced a new mode of exporting
+# called non-strict mode, in which we trace through the program with the
+# Python interpreter, executing it exactly as it would in eager mode. This
+# allows us to skip over unsupported Python features. Non-strict mode is
+# enabled by passing a ``strict=False`` flag.
+#
+# Looking at some of the previous examples that resulted in graph breaks:
+#
+# - Accessing tensor data with ``.data`` now works correctly
+
+class Bad2(torch.nn.Module):
+    def forward(self, x):
+        x.data[0, 0] = 3
+        return x
+
+bad2_nonstrict = export(Bad2(), (torch.randn(3, 3),), strict=False)
+print(bad2_nonstrict.module()(torch.ones(3, 3)))
+
+######################################################################
+# - Calling unsupported functions (such as many built-in functions) traces
+#   through, but in this case, ``id(x)`` gets specialized as a constant
+#   integer in the graph. This is because ``id(x)`` is not a tensor
+#   operation, so it is not recorded in the graph.
+
+class Bad3(torch.nn.Module):
+    def forward(self, x):
+        x = x + 1
+        return x + id(x)
+
+bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
+print(bad3_nonstrict)
+print(bad3_nonstrict.module()(torch.ones(3, 3)))
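To underline the specialization caveat just described, a small hedged check (not in the original commit; reuses ``Bad3`` and ``bad3_nonstrict`` from above): because ``id(x)`` was burned in at trace time, the exported module will generally disagree with eager execution.

eager_out = Bad3()(torch.ones(3, 3))
exported_out = bad3_nonstrict.module()(torch.ones(3, 3))
# Typically False: eager recomputes id(x) on every call, while the
# exported graph replays the constant id captured during tracing
print(torch.equal(eager_out, exported_out))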
+
+######################################################################
+# - Unsupported Python language features (such as throwing exceptions or
+#   ``match`` statements) now also get traced through.
+
+class Bad4(torch.nn.Module):
+    def forward(self, x):
+        try:
+            x = x + 1
+            raise RuntimeError("bad")
+        except:
+            x = x + 2
+        return x
+
+bad4_nonstrict = export(Bad4(), (torch.randn(3, 3),), strict=False)
+print(bad4_nonstrict.module()(torch.ones(3, 3)))
+
+
+######################################################################
+# However, there are still some features that require rewrites to the
+# original module:
 
 ######################################################################
 # Control Flow Ops
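As a bridge to the Control Flow Ops section whose header appears in the context above, here is one hedged sketch of a feature that still needs a rewrite (hypothetical module name; assumes the tutorial's imports): branching on tensor values fails even with ``strict=False``, because non-strict tracing runs on fake tensors that carry shapes but no data; ``torch.cond`` is the supported rewrite.

import traceback as tb
import torch
from torch.export import export

class DataDependentBranch(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # depends on tensor *values*, not just shapes
            return x + 1
        return x - 1

try:
    # Fails even in non-strict mode: fake tensors have no data to branch on
    export(DataDependentBranch(), (torch.randn(3),), strict=False)
except Exception:
    tb.print_exc()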
@@ -365,6 +426,47 @@ def forward(self, x, y):
     except Exception:
         tb.print_exc()
 
+######################################################################
+# We can also describe one dimension in terms of another. There are some
+# restrictions on how precisely one dimension can be specified in terms of
+# another, but generally, expressions of the form ``A * Dim + B`` should work.
+
+class DerivedDimExample1(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y[1:]
+
+foo = DerivedDimExample1()
+
+x, y = torch.randn(5), torch.randn(6)
+dimx = torch.export.Dim("dimx", min=3, max=6)
+dimy = dimx + 1
+derived_dynamic_shapes1 = ({0: dimx}, {0: dimy})
+
+derived_dim_example1 = export(foo, (x, y), dynamic_shapes=derived_dynamic_shapes1)
+
+print(derived_dim_example1.module()(torch.randn(4), torch.randn(5)))
+
+try:
+    derived_dim_example1.module()(torch.randn(4), torch.randn(6))
+except Exception:
+    tb.print_exc()
+
+
+class DerivedDimExample2(torch.nn.Module):
+    def forward(self, z, y):
+        return z[1:] + y[1::3]
+
+foo = DerivedDimExample2()
+
+z, y = torch.randn(4), torch.randn(10)
+dx = torch.export.Dim("dx", min=3, max=6)
+dz = dx + 1
+dy = dx * 3 + 1
+derived_dynamic_shapes2 = ({0: dz}, {0: dy})
+
+derived_dim_example2 = export(foo, (z, y), dynamic_shapes=derived_dynamic_shapes2)
+print(derived_dim_example2.module()(torch.randn(7), torch.randn(19)))
+
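``DerivedDimExample2`` can be given the same negative test that the diff applies to ``DerivedDimExample1``. This hedged addition (not part of the commit; reuses ``derived_dim_example2`` and ``tb`` from above) checks that the exported program's input guards enforce ``dy = 3 * dx + 1``: a length-7 ``z`` fixes ``dx = 6``, so ``y`` must have length 19, and a length of 20 should be rejected.

try:
    derived_dim_example2.module()(torch.randn(7), torch.randn(20))
except Exception:
    tb.print_exc()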
 ######################################################################
 # We can actually use ``torch.export`` to guide us as to which ``dynamic_shapes``
 # constraints are necessary. We can do this by relaxing all constraints (recall
 # that if we
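The context above breaks off mid-sentence at the hunk boundary; the section it opens explains how export can suggest the right constraints. A hedged sketch of that workflow, reusing ``DerivedDimExample1`` from this diff (``dx_guess`` and ``dy_guess`` are illustrative names; the suggested-fix wording in the error is version-dependent):

import traceback as tb
import torch
from torch.export import export

# Start from two fully relaxed, unrelated dims and let export complain
dx_guess = torch.export.Dim("dx_guess")
dy_guess = torch.export.Dim("dy_guess")

try:
    export(DerivedDimExample1(), (torch.randn(5), torch.randn(6)),
           dynamic_shapes=({0: dx_guess}, {0: dy_guess}))
except Exception:
    # The constraint-violation error typically suggests the derived
    # relation, e.g. that dy_guess should equal dx_guess + 1
    tb.print_exc()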
