From 6b4787f1e030286de4b47781337fa55f809fcb56 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Sep 2024 06:11:25 +0000 Subject: [PATCH 1/4] update --- test/nn/models/test_compile.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py index ec53c0d7f..36c760d33 100644 --- a/test/nn/models/test_compile.py +++ b/test/nn/models/test_compile.py @@ -34,7 +34,7 @@ gamma=0.1, ), None, - 7, + 2, id="TabNet", ), pytest.param( @@ -47,21 +47,21 @@ ffn_dropout=0.5, ), None, - 4, + 0, id="TabTransformer", ), pytest.param( Trompt, dict(channels=8, num_prompts=2), None, - 16, + 7, id="Trompt", ), pytest.param( ExcelFormer, dict(in_channels=8, num_cols=3, num_heads=1), [stype.numerical], - 4, + 1, id="ExcelFormer", ), ], @@ -89,4 +89,4 @@ def test_compile_graph_break( **model_kwargs, ) explanation = torch._dynamo.explain(model)(tf) - assert explanation.graph_break_count <= expected_graph_breaks + assert explanation.graph_break_count == expected_graph_breaks From dfe4ca49f589d1b8aeac1fc4e1b1f5b591f4cc3e Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Sep 2024 06:15:23 +0000 Subject: [PATCH 2/4] . --- test/nn/models/test_compile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py index 36c760d33..f7fddeb5e 100644 --- a/test/nn/models/test_compile.py +++ b/test/nn/models/test_compile.py @@ -89,4 +89,5 @@ def test_compile_graph_break( **model_kwargs, ) explanation = torch._dynamo.explain(model)(tf) - assert explanation.graph_break_count == expected_graph_breaks + graph_breaks = explanation.graph_break_count + assert graph_breaks == expected_graph_breaks From d9c75b791469e05a4ea2b854a049c4851669e9b1 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Sep 2024 07:10:08 +0000 Subject: [PATCH 3/4] update --- test/nn/models/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py index f7fddeb5e..de5ff4305 100644 --- a/test/nn/models/test_compile.py +++ b/test/nn/models/test_compile.py @@ -14,7 +14,7 @@ from torch_frame.testing import withPackage -@withPackage("torch>=2.1.0") +@withPackage("torch>=2.4.0") @pytest.mark.parametrize( "model_cls, model_kwargs, stypes, expected_graph_breaks", [ From 9436ac037a920d76795c61efc2952232a9ee8d7a Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Sep 2024 17:04:34 +0900 Subject: [PATCH 4/4] Apply suggestions from code review --- test/nn/models/test_compile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py index de5ff4305..dc22527cf 100644 --- a/test/nn/models/test_compile.py +++ b/test/nn/models/test_compile.py @@ -14,7 +14,7 @@ from torch_frame.testing import withPackage -@withPackage("torch>=2.4.0") +@withPackage("torch>=2.5.0") @pytest.mark.parametrize( "model_cls, model_kwargs, stypes, expected_graph_breaks", [ @@ -54,7 +54,7 @@ Trompt, dict(channels=8, num_prompts=2), None, - 7, + 4, id="Trompt", ), pytest.param(