Skip to content

Commit c624027

Browse files
[mlir][linalg][TransformOps] Connect hoistRedundantVectorTransfers
Connect the hoistRedundantVectorTransfers functionality to the transform dialect. Authored-by: Quentin Colombet <quentin.colombet@gmail.com> Differential Revision: https://reviews.llvm.org/D144260
1 parent be88b58 commit c624027

File tree

9 files changed

+113
-62
lines changed

9 files changed

+113
-62
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
1010
#define MLIR_DIALECT_LINALG_TRANSFORMOPS_LINALGTRANSFORMOPS_H
1111

12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
1314
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1415
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,4 +1706,43 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
17061706
}];
17071707
}
17081708

1709+
//===----------------------------------------------------------------------===//
1710+
// HoistRedundantVectorTransfersOp
1711+
//===----------------------------------------------------------------------===//
1712+
1713+
def HoistRedundantVectorTransfersOp :
1714+
Op<Transform_Dialect, "structured.hoist_redundant_vector_transfers",
1715+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
1716+
TransformEachOpTrait, TransformOpInterface]> {
1717+
let description = [{
1718+
Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
1719+
enclosing scf::ForOp iteratively, if the following conditions are true:
1720+
1. The 2 ops access the same memref with the same indices.
1721+
2. All operands are invariant under the enclosing scf::ForOp.
1722+
3. No uses of the memref either dominate the transfer_read or are
1723+
dominated by the transfer_write (i.e. no aliasing between the write and
1724+
the read across the loop)
1725+
1726+
#### Return modes:
1727+
1728+
The operation always succeeds and returns a handle to the transformed
1729+
function op.
1730+
}];
1731+
1732+
let arguments = (ins TransformHandleTypeInterface:$target);
1733+
let results = (outs TransformHandleTypeInterface:$transformed);
1734+
1735+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
1736+
1737+
let builders = [
1738+
OpBuilder<(ins "Value":$target)>,
1739+
];
1740+
let extraClassDeclaration = [{
1741+
::mlir::DiagnosedSilenceableFailure applyToOne(
1742+
::mlir::func::FuncOp target,
1743+
::mlir::transform::ApplyToEachResultList &results,
1744+
::mlir::transform::TransformState &state);
1745+
}];
1746+
}
1747+
17091748
#endif // LINALG_TRANSFORM_OPS

mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
1010
LINK_LIBS PUBLIC
1111
MLIRAffineDialect
1212
MLIRArithDialect
13+
MLIRFuncDialect
1314
MLIRIR
1415
MLIRLinalgDialect
1516
MLIRLinalgTransforms

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1515
#include "mlir/Dialect/Linalg/IR/Linalg.h"
16+
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
1617
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1718
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1819
#include "mlir/Dialect/PDL/IR/PDL.h"
@@ -3058,6 +3059,19 @@ SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
30583059
return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
30593060
}
30603061

3062+
//===----------------------------------------------------------------------===//
3063+
// HoistRedundantVectorTransfersOp
3064+
//===----------------------------------------------------------------------===//
3065+
3066+
DiagnosedSilenceableFailure
3067+
transform::HoistRedundantVectorTransfersOp::applyToOne(
3068+
func::FuncOp target, transform::ApplyToEachResultList &results,
3069+
transform::TransformState &state) {
3070+
linalg::hoistRedundantVectorTransfers(target);
3071+
linalg::hoistRedundantVectorTransfersOnTensor(target);
3072+
results.push_back(target);
3073+
return DiagnosedSilenceableFailure::success();
3074+
}
30613075
//===----------------------------------------------------------------------===//
30623076
// Transform op registration
30633077
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect -split-input-file | FileCheck %s
1+
// RUN: mlir-opt -test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22

33
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
44
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -74,6 +74,14 @@ func.func @hoist_vector_transfer_pairs(
7474
return
7575
}
7676

77+
transform.sequence failures(propagate) {
78+
^bb1(%arg1: !pdl.operation):
79+
%0 = transform.structured.match ops{["func.func"]} in %arg1
80+
: (!pdl.operation) -> !pdl.operation
81+
transform.structured.hoist_redundant_vector_transfers %0
82+
: (!pdl.operation) -> !pdl.operation
83+
}
84+
7785
// -----
7886

7987
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
@@ -155,6 +163,14 @@ func.func @hoist_vector_transfer_pairs_disjoint(
155163
return
156164
}
157165

166+
transform.sequence failures(propagate) {
167+
^bb1(%arg1: !pdl.operation):
168+
%0 = transform.structured.match ops{["func.func"]} in %arg1
169+
: (!pdl.operation) -> !pdl.operation
170+
transform.structured.hoist_redundant_vector_transfers %0
171+
: (!pdl.operation) -> !pdl.operation
172+
}
173+
158174
// -----
159175

160176
// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
@@ -236,6 +252,14 @@ func.func @hoist_vector_transfer_pairs_tensor(
236252
tensor<?x?xf32>, tensor<?x?xf32>
237253
}
238254

255+
transform.sequence failures(propagate) {
256+
^bb1(%arg1: !pdl.operation):
257+
%0 = transform.structured.match ops{["func.func"]} in %arg1
258+
: (!pdl.operation) -> !pdl.operation
259+
transform.structured.hoist_redundant_vector_transfers %0
260+
: (!pdl.operation) -> !pdl.operation
261+
}
262+
239263
// -----
240264

241265
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
@@ -323,6 +347,14 @@ func.func @hoist_vector_transfer_pairs_disjoint_tensor(
323347
return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
324348
}
325349

350+
transform.sequence failures(propagate) {
351+
^bb1(%arg1: !pdl.operation):
352+
%0 = transform.structured.match ops{["func.func"]} in %arg1
353+
: (!pdl.operation) -> !pdl.operation
354+
transform.structured.hoist_redundant_vector_transfers %0
355+
: (!pdl.operation) -> !pdl.operation
356+
}
357+
326358
// -----
327359

328360
// CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
@@ -432,6 +464,14 @@ func.func @hoist_vector_transfer_pairs_tensor_and_slices(
432464
return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
433465
}
434466

467+
transform.sequence failures(propagate) {
468+
^bb1(%arg1: !pdl.operation):
469+
%0 = transform.structured.match ops{["func.func"]} in %arg1
470+
: (!pdl.operation) -> !pdl.operation
471+
transform.structured.hoist_redundant_vector_transfers %0
472+
: (!pdl.operation) -> !pdl.operation
473+
}
474+
435475
// -----
436476

437477
// CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
@@ -469,6 +509,14 @@ func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
469509
return %1 : tensor<?x?xf32>
470510
}
471511

512+
transform.sequence failures(propagate) {
513+
^bb1(%arg1: !pdl.operation):
514+
%0 = transform.structured.match ops{["func.func"]} in %arg1
515+
: (!pdl.operation) -> !pdl.operation
516+
transform.structured.hoist_redundant_vector_transfers %0
517+
: (!pdl.operation) -> !pdl.operation
518+
}
519+
472520
// -----
473521

474522
// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
@@ -505,3 +553,11 @@ func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi3
505553
}
506554
return
507555
}
556+
557+
transform.sequence failures(propagate) {
558+
^bb1(%arg1: !pdl.operation):
559+
%0 = transform.structured.match ops{["func.func"]} in %arg1
560+
: (!pdl.operation) -> !pdl.operation
561+
transform.structured.hoist_redundant_vector_transfers %0
562+
: (!pdl.operation) -> !pdl.operation
563+
}

mlir/test/lib/Dialect/Linalg/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ add_mlir_library(MLIRLinalgTestPasses
44
TestLinalgDecomposeOps.cpp
55
TestLinalgElementwiseFusion.cpp
66
TestLinalgFusionTransforms.cpp
7-
TestLinalgHoisting.cpp
87
TestLinalgTransforms.cpp
98
TestPadFusion.cpp
109

mlir/test/lib/Dialect/Linalg/TestLinalgHoisting.cpp

Lines changed: 0 additions & 58 deletions
This file was deleted.

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ void registerTestLastModifiedPass();
9595
void registerTestLinalgDecomposeOps();
9696
void registerTestLinalgElementwiseFusion();
9797
void registerTestLinalgGreedyFusion();
98-
void registerTestLinalgHoisting();
9998
void registerTestLinalgTransforms();
10099
void registerTestLivenessPass();
101100
void registerTestLoopFusion();
@@ -205,7 +204,6 @@ void registerTestPasses() {
205204
mlir::test::registerTestLinalgDecomposeOps();
206205
mlir::test::registerTestLinalgElementwiseFusion();
207206
mlir::test::registerTestLinalgGreedyFusion();
208-
mlir::test::registerTestLinalgHoisting();
209207
mlir::test::registerTestLinalgTransforms();
210208
mlir::test::registerTestLivenessPass();
211209
mlir::test::registerTestLoopFusion();

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8350,6 +8350,7 @@ cc_library(
83508350
":AsmParser",
83518351
":ControlFlowDialect",
83528352
":DialectUtils",
8353+
":FuncDialect",
83538354
":GPUDialect",
83548355
":IR",
83558356
":LinalgDialect",

0 commit comments

Comments
 (0)