Skip to content

Commit 44d1dd3

Browse files
authored
increase torch dynamo cache size limit to support all tests (#2470)
increase torch dynamo cache size limit to support all tests (#2470) Summary: Some tests was suppressed due to cache size limitation. This diff increase the cache size to unblock the issue. Reviewed By: jerryzh168 Differential Revision: D76794879
1 parent 5fa5e4c commit 44d1dd3

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

test/quantization/pt2e/test_numeric_debugger.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@
2323
if TORCH_VERSION_AT_LEAST_2_8:
2424
from torch.export import export_for_training
2525

26+
# Increase cache size limit to avoid FailOnRecompileLimitHit error when running multiple tests
27+
# that use export_for_training, which causes many dynamo recompilations
28+
if TORCH_VERSION_AT_LEAST_2_8:
29+
torch._dynamo.config.cache_size_limit = 128
30+
2631

2732
@unittest.skipIf(
2833
not TORCH_VERSION_AT_LEAST_2_8, "Requires torch 2.8 and above, including nightly"
2934
)
3035
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
3136
class TestNumericDebuggerInfra(PT2ENumericDebuggerTestCase):
32-
@unittest.skip(
33-
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
34-
)
3537
def test_simple(self):
3638
m = TestHelperModules.Conv2dThenConv1d()
3739
example_inputs = m.example_inputs()
@@ -88,9 +90,6 @@ def test_deepcopy_preserve_handle(self):
8890
set(from_node_source_map.values()), set(from_node_source_map_ref.values())
8991
)
9092

91-
@unittest.skip(
92-
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
93-
)
9493
def test_re_export_preserve_handle(self):
9594
m = TestHelperModules.Conv2dThenConv1d()
9695
example_inputs = m.example_inputs()
@@ -108,9 +107,6 @@ def test_re_export_preserve_handle(self):
108107

109108
self.assertEqual(from_node_source_map, from_node_source_map_ref)
110109

111-
@unittest.skip(
112-
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
113-
)
114110
def test_run_decompositions_same_handle_id(self):
115111
m = TestHelperModules.Conv2dThenConv1d()
116112
example_inputs = m.example_inputs()
@@ -132,9 +128,6 @@ def test_run_decompositions_same_handle_id(self):
132128
set(from_node_source_map.values()), set(from_node_source_map_ref.values())
133129
)
134130

135-
@unittest.skip(
136-
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
137-
)
138131
def test_run_decompositions_map_handle_to_new_nodes(self):
139132
test_models = [
140133
TestHelperModules.TwoLinearModule(),

torchao/testing/pt2e/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _extract_from_node_source_with_prev_decomp_op_from_node(node):
180180
nonlocal prev_decomp_op_to_from_node_source_map
181181
if FROM_NODE_KEY in node.meta and node.meta[FROM_NODE_KEY] is not None:
182182
prev_decomp_op = str(node.meta.get("nn_module_stack"))
183-
from_node_source = node.meta[FROM_NODE_KEY]
183+
from_node_source = _extract_node_source_debug_info(node)
184184
if prev_decomp_op not in prev_decomp_op_to_from_node_source_map:
185185
prev_decomp_op_to_from_node_source_map[prev_decomp_op] = (
186186
from_node_source
@@ -189,8 +189,10 @@ def _extract_from_node_source_with_prev_decomp_op_from_node(node):
189189
assert (
190190
prev_decomp_op_to_from_node_source_map[prev_decomp_op]
191191
== from_node_source
192-
), f"Node {node} has different from_node info {from_node_source}"
193-
"than previous node sharing the same decomp op {prev_decomp_op}"
192+
), (
193+
f"Node {node} has different from_node info {from_node_source}"
194+
f"than previous node sharing the same decomp op {prev_decomp_op}"
195+
)
194196

195197
bfs_trace_with_node_process(
196198
model, _extract_from_node_source_with_prev_decomp_op_from_node

0 commit comments

Comments
 (0)