Skip to content

Commit da237d0

Browse files
authored
LLVM / Torch-Mlir Upgrade (#8)
- Remove StableHLO dependency - Apply patches to build torch-mlir without stablehlo and stablehlo lit tests - Update APIs realted to` isa<>`, `cast<>`, and `dyn_cast<>` - Add a new EliminateUnusedTorchOpsPass to remove unused torch ops in torch-to-tcp (see example below): ``` module { func.func @func_main(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> { %int2 = torch.constant.int 2 %int0 = torch.constant.int 0 %0 = torch.symbolic_int "s35" {min_val = 0, max_val = 9223372036854775807} : !torch.int torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> %1 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int %2 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int %3 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.int -> !torch.vtensor<[?,3],f32> torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> %4 = torch.aten.eq.int %1, %2 : !torch.int, !torch.int -> !torch.bool %5 = torch.aten.Int.bool %4 : !torch.bool -> !torch.int %6 = torch.aten.Bool.int %5 : !torch.int -> !torch.bool torch.runtime.assert %6, "Runtime assertion failed for expression Eq(s35, s58) on node 'eq_2'" return %3 : !torch.vtensor<[?,3],f32> } } ``` where the `torch.runtime.assert` and related checking ops can be removed.
1 parent ef17021 commit da237d0

33 files changed

+558
-565
lines changed

.bazelignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44
# Also available under a BSD-style license. See LICENSE.
55

6-
# ignore local_repos of llvm-project, torch-mlir, stablehlo
6+
# ignore local_repos of llvm-project, torch-mlir
77
third_party/llvm-project
88
third_party/torch-mlir
9-
third_party/stablehlo

.github/workflows/bazelBuildAndTestStablehlo.yml

Lines changed: 0 additions & 64 deletions
This file was deleted.

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@ bazel-out
99
bazel-mlir-tcp
1010
bazel-testlogs
1111

12-
# ignore local_repos of llvm, torch-mlir, stablehlo
12+
# ignore local_repos of llvm, torch-mlir
1313
third_party/llvm-project
1414
third_party/torch-mlir
15-
third_party/stablehlo
1615

1716
# clangd related
1817
.cache

BUILD

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ cc_library(
150150
name = "TcpDialectPasses",
151151
srcs = [
152152
"lib/Dialect/Transforms/DropSymbolicShapeOpsPass.cpp",
153+
"lib/Dialect/Transforms/EliminateUnusedTorchOpsPass.cpp",
153154
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
154155
"lib/Dialect/Transforms/FusionPatterns.cpp",
155156
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
@@ -160,6 +161,7 @@ cc_library(
160161
],
161162
hdrs = [
162163
"include/mlir-tcp/Dialect/Transforms/DropSymbolicShapeOpsPass.h",
164+
"include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h",
163165
"include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
164166
"include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
165167
"include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
@@ -175,6 +177,7 @@ cc_library(
175177
"@llvm-project//mlir:TensorDialect",
176178
"@llvm-project//mlir:TensorTransforms",
177179
"@llvm-project//mlir:Transforms",
180+
"@torch-mlir//:TorchMLIRTorchDialect",
178181
],
179182
)
180183

@@ -198,7 +201,6 @@ cc_library(
198201
hdrs = ["include/mlir-tcp/Conversion/Passes.h"],
199202
strip_include_prefix = "include",
200203
deps = [
201-
":StablehloToTcp",
202204
":TcpToArith",
203205
":TcpToLinalg",
204206
":TcpToTensor",
@@ -237,25 +239,6 @@ cc_library(
237239
],
238240
)
239241

240-
cc_library(
241-
name = "StablehloToTcp",
242-
srcs = [
243-
"lib/Conversion/PassDetail.h",
244-
"lib/Conversion/StablehloToTcp/StablehloToTcp.cpp",
245-
],
246-
hdrs = ["include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"],
247-
strip_include_prefix = "include",
248-
deps = [
249-
":TcpConversionPassesIncGen",
250-
":TcpDialect",
251-
"@llvm-project//mlir:Dialect",
252-
"@llvm-project//mlir:LinalgDialect",
253-
"@llvm-project//mlir:Pass",
254-
"@llvm-project//mlir:Transforms",
255-
"@stablehlo//:stablehlo_ops",
256-
],
257-
)
258-
259242
cc_library(
260243
name = "TcpToLinalg",
261244
srcs = [
@@ -364,6 +347,5 @@ cc_binary(
364347
"@llvm-project//mlir:AllPassesAndDialects",
365348
"@llvm-project//mlir:MlirOptLib",
366349
"@llvm-project//mlir:QuantOps",
367-
"@stablehlo//:register",
368350
],
369351
)

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,15 @@ bazel run //tools/clangd:refresh_compile_commands
5151
```
5252
When run successfully, a `compile_commands.json` is generated at the workspace root (and refreshed upon re-runs). If you're using VSCode, just hit CMD+SHIFT+P and select `clangd: Restart language server` to start clangd. Note that this only works for non-docker builds at the moment.
5353

54-
When bumping upstream dependencies (LLVM, Torch-MLIR, StableHLO), you may validate the set of "green commits" by running the corresponding third-party tests:
54+
When bumping upstream dependencies (LLVM, Torch-MLIR), you may validate the set of "green commits" by running the corresponding third-party tests:
5555
```shell
5656
bazel test @llvm-project//mlir/...
5757
bazel test @torch-mlir//...
58-
bazel test @stablehlo//...
5958
```
6059

6160
The following CI workflows are automatically triggered anytime upstream dependencies (`deps.bzl`) are updated:
6261
- [![Bazel Build and Test (llvm-project)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestLlvm.yml/badge.svg)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestLlvm.yml)
6362
- [![Bazel Build and Test (torch-mlir)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestTorchmlir.yml/badge.svg)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestTorchmlir.yml)
64-
- [![Bazel Build and Test (stablehlo)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestStablehlo.yml/badge.svg)](https://github.com/llvm/mlir-tcp/actions/workflows/bazelBuildAndTestStablehlo.yml)
6563

6664
To use newer `torch-mlir` and/or `torch` python packages in our hermetic python sandbox, just regenerate `requirements_lock.txt` as follows:
6765
```shell

deps.bzl

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
77
load(
88
":local_repos.bzl",
99
"local_llvm_repo_path",
10-
"local_stablehlo_repo_path",
1110
"local_torch_mlir_repo_path",
1211
"use_local_llvm_repo",
13-
"use_local_stablehlo_repo",
1412
"use_local_torch_mlir_repo",
1513
)
1614

@@ -22,8 +20,8 @@ def third_party_deps():
2220
path = local_llvm_repo_path(),
2321
)
2422
else:
25-
LLVM_COMMIT = "72144d119a7291f8b6b8e022a2947fbe31e66afc"
26-
LLVM_SHA256 = "2caacb6925a13cb5886a5d7f225fa408b80ca8e1efe0736186954b2abc4ee1c3"
23+
LLVM_COMMIT = "b231e5ff504295641b0f580ceefa2e1048011614"
24+
LLVM_SHA256 = "88dfa59052730710cb48fa20b00a4344144edd1c3cb524c06d983899835e491a"
2725
http_archive(
2826
name = "llvm-raw",
2927
build_file_content = "# empty",
@@ -39,32 +37,15 @@ def third_party_deps():
3937
path = local_torch_mlir_repo_path(),
4038
)
4139
else:
42-
TORCH_MLIR_COMMIT = "9f2ba5abaa85cefd95cc85579fafd0c53c1101e8"
43-
TORCH_MLIR_SHA256 = "09444281839eeae4aff42c029d87b1728f307fa26511b896ff448d51aaa98049"
40+
TORCH_MLIR_COMMIT = "1ad9702d2a290b693c4f6f17921d0e0a8d14a999"
41+
TORCH_MLIR_SHA256 = "8843399168c34ca3ca16d2417703fe4e1440ca7240d9e04844b3deedf256f0ab"
4442
http_archive(
4543
name = "torch-mlir-raw",
4644
build_file_content = "# empty",
45+
patches = ["//third_party/patches:torch-mlir-bazel-build.1.patch", "//third_party/patches:torch-mlir-bazel-build.2.patch"],
4746
sha256 = TORCH_MLIR_SHA256,
4847
strip_prefix = "torch-mlir-" + TORCH_MLIR_COMMIT,
4948
urls = ["https://github.com/llvm/torch-mlir/archive/{commit}.tar.gz".format(commit = TORCH_MLIR_COMMIT)],
50-
patches = [
51-
"//third_party/patches:torch-mlir.1.patch",
52-
],
53-
)
54-
55-
if use_local_stablehlo_repo():
56-
native.local_repository(
57-
name = "stablehlo",
58-
path = local_stablehlo_repo_path(),
59-
)
60-
else:
61-
STABLEHLO_COMMIT = "a54938f0651d3b4b7be9771848eda2463c92a8e7"
62-
STABLEHLO_SHA256 = "edab2288f0b19e3efbf08815d17d4efb106984aa6fe02fed0cb2165284e6a5b7"
63-
http_archive(
64-
name = "stablehlo",
65-
sha256 = STABLEHLO_SHA256,
66-
strip_prefix = "stablehlo-" + STABLEHLO_COMMIT,
67-
urls = ["https://github.com/openxla/stablehlo/archive/{commit}.tar.gz".format(commit = STABLEHLO_COMMIT)],
6849
)
6950

7051
SKYLIB_VERSION = "1.3.0"

include/mlir-tcp/Conversion/Passes.td

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,6 @@ def ConvertTorchToTcpCustomOp : Pass<"convert-torch-to-tcp-custom-op", "func::Fu
4646
];
4747
}
4848

49-
//===----------------------------------------------------------------------===//
50-
// StablehloToTcp
51-
//===----------------------------------------------------------------------===//
52-
53-
def ConvertStablehloToTcp
54-
: Pass<"convert-stablehlo-to-tcp", "func::FuncOp"> {
55-
let summary = "Lower StableHLO to TCP";
56-
let description = [{
57-
Pass that converts StableHLO operations to equivalent operations in TCP.
58-
}];
59-
60-
let constructor = "mlir::tcp::createConvertStablehloToTcpPass()";
61-
let dependentDialects = [
62-
"mlir::tcp::TcpDialect",
63-
];
64-
}
65-
6649
//===----------------------------------------------------------------------===//
6750
// TcpToLinalg
6851
//===----------------------------------------------------------------------===//

include/mlir-tcp/Dialect/IR/TcpTypes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ include "mlir-tcp/Dialect/IR/TcpBase.td"
2424
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
2525
// the 8-bit case.
2626
class Tcp_QuantizedType<string n, list<int> params, bit signed>
27-
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
28-
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
27+
: Type<And<[CPred<"::llvm::isa<mlir::quant::QuantizedType>($_self)">,
28+
CPred<"::llvm::cast<mlir::quant::QuantizedType>($_self)" #
2929
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
3030
"Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
3131
string name = n;

include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h renamed to include/mlir-tcp/Dialect/Transforms/EliminateUnusedTorchOpsPass.h

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,13 @@
99

1010
#pragma once
1111

12-
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/IR/BuiltinOps.h"
1313
#include "mlir/Pass/Pass.h"
14+
#include <memory>
1415

15-
namespace mlir {
16+
namespace mlir::tcp {
1617

17-
#define GEN_PASS_DECL_CONVERTSTABLEHLOTOTCP
18-
#include "mlir-tcp/Conversion/Passes.h.inc"
18+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
19+
createEliminateUnusedTorchOpsPass();
1920

20-
namespace tcp {
21-
22-
std::unique_ptr<OperationPass<func::FuncOp>> createConvertStablehloToTcpPass();
23-
24-
} // namespace tcp
25-
} // namespace mlir
21+
} // namespace mlir::tcp

include/mlir-tcp/Dialect/Transforms/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,10 @@ def DropSymbolicShapeOps : Pass<"drop-symbolic-shape-ops", "func::FuncOp"> {
4545
let constructor = "mlir::tcp::createDropSymbolicShapeOpsPass()";
4646
}
4747

48+
// \brief This pass removes unused torch ops.
49+
def EliminateUnusedTorchOps : Pass<"eliminate-unused-torch-ops", "ModuleOp"> {
50+
let summary = "Removes unused/unnecessary torch ops";
51+
let constructor = "mlir::tcp::createEliminateUnusedTorchOpsPass()";
52+
}
53+
4854
#endif // TCP_PASSES

lib/Conversion/Passes.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
#include "mlir-tcp/Conversion/Passes.h"
1111

12-
#include "mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"
1312
#include "mlir-tcp/Conversion/TcpToArith/TcpToArith.h"
1413
#include "mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"
1514
#include "mlir-tcp/Conversion/TcpToTensor/TcpToTensor.h"

lib/Conversion/StablehloToTcp/StablehloToTcp.cpp

Lines changed: 0 additions & 75 deletions
This file was deleted.

0 commit comments

Comments
 (0)