diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h index 4edf432d9d97d..39c3451b1369f 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -27,8 +27,8 @@ namespace linalg { /// dominated by the transfer_write (i.e. no aliasing between the write and /// the read across the loop) /// 4. The source operands for vector.transfer_{read|write} do not originate -/// from Ops implementing ViewLikeOpInterface (to reduce the risk of -/// aliasing). +/// from ops implementing ViewLikeOpInterface (to reduce the risk of +/// aliasing), except memref::AssumeAlignmentOp. /// 5. If `verifyNonZeroTrip` is true, then the lower bound of the loop must /// be statically smaller than the upper bound of the loop, guaranteeing that /// the loop body will execute at least once. @@ -39,8 +39,8 @@ namespace linalg { /// /// TODO: To further improve hoisting opportunities, fold aliasing memref /// operations into respective vector.transfer{read|write} operations and -/// avoid using ops implementing ViewLikeOpInterface as the source for transfer -/// Ops. +/// avoid using ops implementing ViewLikeOpInterface, except +/// memref::AssumeAlignmentOp, as the source for transfer ops. /// /// WARNING: This hoisting does not model parallelism and is generally incorrect /// when used on distributed loops with memref semantics! 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 707b63ff9335b..d473a1cff08cc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -303,7 +304,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, // 1. indices, vector type and permutation map are the same (i.e., the // transfer_read/transfer_write ops are matching), // 2. source operands for transfer.{read|write} do not originate from - // Ops implementing ViewLikeOpInterface. + // ops implementing ViewLikeOpInterface, except + // memref::AssumeAlignmentOp. // 3. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. 
@@ -313,11 +315,13 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root, return WalkResult::advance(); auto *source = transferRead.getBase().getDefiningOp(); - if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) + if (source && isa_and_nonnull<ViewLikeOpInterface>(source) && + !isa<memref::AssumeAlignmentOp>(source)) return WalkResult::advance(); source = transferWrite.getBase().getDefiningOp(); - if (source && isa_and_nonnull<ViewLikeOpInterface>(source)) + if (source && isa_and_nonnull<ViewLikeOpInterface>(source) && + !isa<memref::AssumeAlignmentOp>(source)) return WalkResult::advance(); // TODO: may want to memoize this information for performance but it diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 318edca73cce1..a71bd044d98c1 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -18,11 +18,13 @@ func.func @hoist_vector_transfer_pairs( %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) { %c0 = arith.constant 0 : index %cst = arith.constant 0.0 : f32 + %assume_align = memref.assume_alignment %memref0, 64 : memref<?x?xf32> // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32> -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) { +// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32> +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<1xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32> -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) { +// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>, vector<1xf32>) { // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32> // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32> // CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> () @@ -43,6 +45,7 @@ func.func @hoist_vector_transfer_pairs( // CHECK: scf.yield {{.*}} : vector<1xf32> // CHECK: } // 
CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32> +// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32> // CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> () scf.for %i = %lb to %ub step %step { scf.for %j = %lb to %ub step %step { @@ -53,6 +56,7 @@ func.func @hoist_vector_transfer_pairs( "some_crippling_use"(%memref4) : (memref<?x?xf32>) -> () %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32> %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32> + %r6 = vector.transfer_read %assume_align[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32> "some_crippling_use"(%memref5) : (memref<?x?xf32>) -> () %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> @@ -60,12 +64,14 @@ func.func @hoist_vector_transfer_pairs( %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> + %u6 = "some_use"(%r6) : (vector<1xf32>) -> vector<1xf32> vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32> vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32> vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32> vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32> vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32> vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32> + vector.transfer_write %u6, %assume_align[%c0, %c0] : vector<1xf32>, memref<?x?xf32> "some_crippling_use"(%memref3) : (memref<?x?xf32>) -> () } "unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()