Commit 4bd7d03

build: manually update PyTorch version (llvm#4112)
This commit sets the PyTorch and TorchVision versions to the nightly release 2025-04-23. It also updates the fx_importer tests for the changes to symbolic shape ids introduced in pytorch/pytorch@f649ee7. Some checks are disabled because the resultant IR differs between the PyTorch nightly and stable versions; they will be re-enabled once the two versions produce the same IR.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
1 parent 44d6e4e commit 4bd7d03

9 files changed (+86, -51 lines)


lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 35 additions & 6 deletions
@@ -15404,6 +15404,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " return %0 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.linalg_vector_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
+" %int6 = torch.constant.int 6\n"
+" %int15 = torch.constant.int 15\n"
+" %int5 = torch.constant.int 5\n"
 " %true = torch.constant.bool true\n"
 " %none = torch.constant.none\n"
 " %str = torch.constant.str \"AssertionError: \"\n"
@@ -15452,12 +15455,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " torch.prim.If.yield %9 : !torch.int\n"
 " } else {\n"
-" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-" torch.prim.If.yield %5 : !torch.int\n"
+" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+" %7 = torch.prim.If %6 -> (!torch.int) {\n"
+" torch.prim.If.yield %int6 : !torch.int\n"
+" } else {\n"
+" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+" torch.prim.If.yield %8 : !torch.int\n"
+" }\n"
+" torch.prim.If.yield %7 : !torch.int\n"
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<number>, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
+" %int6 = torch.constant.int 6\n"
+" %int15 = torch.constant.int 15\n"
+" %int5 = torch.constant.int 5\n"
 " %true = torch.constant.bool true\n"
 " %none = torch.constant.none\n"
 " %str = torch.constant.str \"AssertionError: \"\n"
@@ -15506,8 +15519,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " torch.prim.If.yield %9 : !torch.int\n"
 " } else {\n"
-" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-" torch.prim.If.yield %5 : !torch.int\n"
+" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+" %7 = torch.prim.If %6 -> (!torch.int) {\n"
+" torch.prim.If.yield %int6 : !torch.int\n"
+" } else {\n"
+" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+" torch.prim.If.yield %8 : !torch.int\n"
+" }\n"
+" torch.prim.If.yield %7 : !torch.int\n"
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"
@@ -15531,6 +15551,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
 " %true = torch.constant.bool true\n"
+" %int6 = torch.constant.int 6\n"
+" %int15 = torch.constant.int 15\n"
 " %int5 = torch.constant.int 5\n"
 " %int8 = torch.constant.int 8\n"
 " %none = torch.constant.none\n"
@@ -15548,8 +15570,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = torch.prim.If %3 -> (!torch.int) {\n"
 " torch.prim.If.yield %int5 : !torch.int\n"
 " } else {\n"
-" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
-" torch.prim.If.yield %5 : !torch.int\n"
+" %5 = torch.prim.ListConstruct %int5, %int15 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %6 = torch.aten.__contains__.int_list %5, %0#1 : !torch.list<int>, !torch.int -> !torch.bool\n"
+" %7 = torch.prim.If %6 -> (!torch.int) {\n"
+" torch.prim.If.yield %int6 : !torch.int\n"
+" } else {\n"
+" %8 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple<int, int>, !torch.bool) -> !torch.int\n"
+" torch.prim.If.yield %8 : !torch.int\n"
+" }\n"
+" torch.prim.If.yield %7 : !torch.int\n"
 " }\n"
 " return %4 : !torch.int\n"
 " }\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 6 additions & 0 deletions
@@ -5293,6 +5293,8 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
+    if self_dtype in [torch.float16, torch.bfloat16]:
+        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function(
@@ -5316,6 +5318,8 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
             return aten〇std〡dtype((self_rank, dtype))
         assert not is_complex_dtype(dtype)
         return dtype
+    if self_dtype in [torch.float16, torch.bfloat16]:
+        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int:
@@ -5349,6 +5353,8 @@ def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int,
     # Should possibly be added to aten〇std〡dtype.
     if self_dtype == torch.complex32:
         return torch.half
+    if self_dtype in [torch.float16, torch.bfloat16]:
+        return torch.float32
     return aten〇std〡dtype(self_rank_dtype)
 
 @check_dtype_function([Invocation(0.0),
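
The updated dtype functions encode that reduced-precision float inputs to these norm ops produce a float32 result. A quick eager-mode probe is shown below; it is illustrative only, and the printed dtypes depend on the installed PyTorch build (the pinned nightly is expected to report float32 here, which is what the rule above models).

import torch

x_half = torch.randn(8, dtype=torch.float16)
x_bf16 = torch.randn(8, dtype=torch.bfloat16)

# Result dtypes for reduced-precision inputs to the norm ops touched above.
print(torch.linalg.vector_norm(x_half).dtype)
print(torch.linalg.vector_norm(x_bf16).dtype)
print(torch.norm(x_half, p=2).dtype)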

pytorch-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-3794824ceb12a9d4396eaa17795bf2147fd9e1c3
+dab7e5700392e4e20626de9c367acb76187807f5

pytorch-requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.8.0.dev20250325
+torch==2.8.0.dev20250423
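
The hash and requirements pins are expected to match the locally installed nightly when running the fx_importer tests. A minimal sanity check (an illustrative snippet, not part of the repository):

import torch

PINNED = "2.8.0.dev20250423"  # from pytorch-requirements.txt in this commit
assert torch.__version__.startswith(PINNED), (
    f"installed torch {torch.__version__} does not match the pinned nightly {PINNED}"
)
print("torch pin OK:", torch.__version__)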

test/python/fx_importer/basic_test.py

Lines changed: 5 additions & 5 deletions
@@ -88,11 +88,11 @@ def forward(self, x):
 @run
 # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
 # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32>
-# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
+# CHECK: %[[S0:.*]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: %[[S1:.*]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int
+# CHECK-DISABLED: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 5)> : !torch.vtensor<[?,?,5],f32>
 # CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32>
-# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32>
+# CHECK-DISABLED: torch.bind_symbolic_shape %[[TANH]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 5)> : !torch.vtensor<[?,?,5],f32>
 # CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32>
 def test_import_frozen_exported_program_with_dynamic_shapes():
     class Basic(nn.Module):
@@ -118,7 +118,7 @@ def forward(self, x):
 @run
 # CHECK-LABEL: test_broadcast_with_dynamic_shapes
 # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32>
-# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
+# CHECK: %[[S0:.*]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
 # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
 # CHECK: torch.aten.size.int
 # CHECK: torch.prim.ListConstruct
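
The relaxed patterns follow pytorch/pytorch@f649ee7: the exported program no longer guarantees the "s0"/"s1" symbol names these tests used to match, so the checks now accept any lowercase alphanumeric id, and the bind_symbolic_shape lines whose symbol order differs between nightly and stable are temporarily renamed to CHECK-DISABLED so they are not enforced. The sketch below is a rough reproduction of the kind of exported program this test imports, using only the public torch.export API; the dimension names and bounds are illustrative, not taken from the test.

import torch
from torch.export import Dim, export

class Basic(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)

# Two dynamic leading dims, matching the !torch.vtensor<[?,?,5],f32> in the checks.
batch = Dim("batch", max=10)
seq = Dim("seq", min=2, max=10)
prog = export(Basic(), (torch.randn(4, 3, 5),),
              dynamic_shapes={"x": {0: batch, 1: seq}})
# Printing the program shows the symbolic sizes; their names are what the
# relaxed {{[a-z0-9]+}} patterns above now match.
print(prog)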

test/python/fx_importer/custom_op_test.py

Lines changed: 8 additions & 8 deletions
@@ -26,15 +26,15 @@ def run(f):
 # CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
 # CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
 # CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
-# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
-# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
-# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
-# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
-# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK: %[[S0:.+]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = 5, max_val = 10} : !torch.int
+# CHECK: %[[S1:.+]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
+# CHECK: %[[S2:.+]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
+# CHECK: %[[S3:.+]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
+# CHECK-DISABLED: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32>
 # CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
-# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK-DISABLED: torch.bind_symbolic_shape %[[ARG2]], [%[[S3]], %[[S0]]], affine_map<()[s0, s1] -> (s1, s0, 3)> : !torch.vtensor<[?,?,3],f32>
 # CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
-# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
+# CHECK-DISABLED: torch.bind_symbolic_shape %[[OP]], [%[[S1]], %[[S3]], %[[S0]], %[[S2]]], affine_map<()[s0, s1, s2, s3] -> (s2, s1 + s3 + s0 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
 # CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
 def test_tanh_sigmoid_cat_custom_op():
 
@@ -89,7 +89,7 @@ def forward(self, x, y, z):
 @run
 # CHECK-LABEL: test_custom_op_array_output
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>)
-# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
+# CHECK: %[[S0:.+]] = torch.symbolic_int "{{[a-z0-9]+}}" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
 # CHECK: %[[int:.+]] = torch.constant.int 4
 # CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list<vtensor>
 # CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list<vtensor> -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
