diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index 9a368f372c296..60c8660658a95 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -52,13 +53,25 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
 }
 
 static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
-                                           vector::MaskedLoadOp maskedOp) {
+                                           vector::MaskedLoadOp maskedOp,
+                                           bool passthru) {
   VectorType vectorType = maskedOp.getVectorType();
   Value load = builder.create<vector::LoadOp>(
       loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
-  Value res = builder.create<arith::SelectOp>(
-      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
-  return res;
+  if (passthru)
+    load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
+                                           load, maskedOp.getPassThru());
+  return load;
+}
+
+/// Check if the given value comes from a broadcasted i1 condition.
+static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return failure();
+  if (isa<VectorType>(broadcastOp.getSourceType()))
+    return failure();
+  return broadcastOp.getSource();
 }
 
 static constexpr char kMaskedloadNeedsMask[] =
@@ -78,6 +91,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
       return failure();
     }
 
+    // Check if this is either a full inbounds load or an empty, oob load. If
+    // so, take the fast path and don't generate an if condition, because we
+    // know doing the oob load is always safe.
+    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+                                                 maskedOp, /*passthru=*/true);
+      rewriter.replaceOp(maskedOp, load);
+      return success();
+    }
+
     Location loc = maskedOp.getLoc();
     Value src = maskedOp.getBase();
 
@@ -135,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
+                                                /*passthru=*/true);
       rewriter.create<scf::YieldOp>(loc, res);
     };
 
@@ -148,11 +172,63 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   }
 };
 
+struct FullMaskedLoadToConditionalLoad
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+
+    Value cond = maybeCond.value();
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
+                                                /*passthru=*/false);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
+                                           falseBuilder);
+    rewriter.replaceOp(loadOp, ifOp);
+    return success();
+  }
+};
+
+struct FullMaskedStoreToConditionalStore
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+    Value cond = maybeCond.value();
+
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
+                                       storeOp.getBase(), storeOp.getIndices());
+      rewriter.create<scf::YieldOp>(loc);
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+    rewriter.replaceOp(storeOp, ifOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
+               FullMaskedStoreToConditionalStore>(patterns.getContext(),
+                                                  benefit);
 }
 
 struct AmdgpuMaskedloadToLoadPass final
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index febe46bf7a759..f1d0ad545539a 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -114,3 +114,56 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
   %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   return %res : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_fatrawbuffer_to_load
+func.func @full_select_maskedload_fatrawbuffer_to_load(%arg0: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: vector.load
+// CHECK: arith.select
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_mask_maskedstore_to_store
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  vector.maskedstore %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
+  return
+}
+// CHECK-NOT: vector.maskedstore
+// CHECK: scf.if %[[PRED]]
+// CHECK: vector.store
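For illustration, here is a sketch of the rewrite the new FullMaskedLoadToConditionalLoad pattern is expected to perform on the @full_select_maskedload_to_load test above, assuming the pattern applies as written; the SSA names (%pred, %mem, %idx, %passthru) are illustrative rather than actual pass output.

  // Before: the mask is a broadcast of a single scalar i1 condition %pred.
  %mask = vector.broadcast %pred : i1 to vector<4xi1>
  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru
      : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>

  // After: the scalar condition guards an ordinary vector.load, and the false
  // branch yields the passthru vector, matching the CHECK lines above.
  %res = scf.if %pred -> (vector<4xf16>) {
    %load = vector.load %mem[%idx, %idx] : memref<8x8xf16>, vector<4xf16>
    scf.yield %load : vector<4xf16>
  } else {
    scf.yield %passthru : vector<4xf16>
  }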