From 8b0c534e597b9dd9c1f89fade5daab0e3e1469e6 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Wed, 2 Jul 2025 18:25:45 +0100
Subject: [PATCH 1/3] [mlir][AMDGPU] Add better load/store lowering for full select mask

---
 .../AMDGPU/Transforms/MaskedloadToLoad.cpp    | 95 ++++++++++++++++++-
 .../Dialect/AMDGPU/maskedload-to-load.mlir    | 25 +++++
 2 files changed, 119 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index 9a368f372c296..b290dc46910e3 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -61,6 +61,36 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return res;
 }
 
+/// Check if the given value comes from a:
+///
+/// arith.select %cond, TRUE/FALSE, TRUE/FALSE
+///
+/// i.e the condition is either always true or it's always false.
+///
+/// Returns the condition to use for scf.if (condition) { true } else { false }.
+static FailureOr<Value> matchFullSelect(OpBuilder &b, Value val) {
+  auto selectOp = val.getDefiningOp<arith::SelectOp>();
+  if (!selectOp)
+    return failure();
+  std::optional<int64_t> trueInt = getConstantIntValue(selectOp.getTrueValue());
+  std::optional<int64_t> falseInt =
+      getConstantIntValue(selectOp.getFalseValue());
+  if (!trueInt || !falseInt)
+    return failure();
+  // getConstantIntValue returns -1 for "true" for bools.
+  if (trueInt.value() == -1 && falseInt.value() == 0)
+    return selectOp.getCondition();
+
+  if (trueInt.value() == 0 && falseInt.value() == -1) {
+    Value cond = selectOp.getCondition();
+    Value one = b.create<arith::ConstantIntOp>(cond.getLoc(), /*value=*/true,
+                                               /*width=*/1);
+    Value inverse = b.create<arith::XOrIOp>(cond.getLoc(), cond, one);
+    return inverse;
+  }
+  return failure();
+}
+
 static constexpr char kMaskedloadNeedsMask[] =
     "amdgpu.buffer_maskedload_needs_mask";
 
@@ -78,6 +108,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
       return failure();
     }
 
+    // Check if this is either a full inbounds load or an empty, oob load. If
+    // so, take the fast path and don't generate a if condition, because we know
+    // doing the oob load is always safe.
+    if (succeeded(matchFullSelect(rewriter, maskedOp.getMask()))) {
+      Value load =
+          createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(), maskedOp);
+      rewriter.replaceOp(maskedOp, load);
+      return success();
+    }
+
     Location loc = maskedOp.getLoc();
     Value src = maskedOp.getBase();
 
@@ -148,11 +188,64 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   }
 };
 
+struct FullMaskedLoadToConditionalLoad
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullSelect(rewriter, loadOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+
+    Value cond = maybeCond.value();
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
+                                           falseBuilder);
+    rewriter.replaceOp(loadOp, ifOp);
+    return success();
+  }
+};
+
+struct FullMaskedStoreToConditionalStore
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullSelect(rewriter, storeOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+    Value cond = maybeCond.value();
+
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
+                                       storeOp.getBase(), storeOp.getIndices());
+      rewriter.create<scf::YieldOp>(loc);
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+    rewriter.replaceOp(storeOp, ifOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
+               FullMaskedStoreToConditionalStore>(patterns.getContext(),
+                                                  benefit);
 }
 
 struct AmdgpuMaskedloadToLoadPass final
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index febe46bf7a759..d6682ba14eeca 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -114,3 +114,28 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+func.func @full_select_maskedload_fatrawbuffer_to_load(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
+  %true = arith.constant dense<true> : vector<4xi1>
+  %false = arith.constant dense<false> : vector<4xi1>
+  %mask = arith.select %cond, %true, %false : vector<4xi1>
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %res : vector<4xf16>
+}
+
+func.func @full_select_maskedload_to_load(%mem : memref<8x8xf16>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
+  %true = arith.constant dense<true> : vector<4xi1>
+  %false = arith.constant dense<false> : vector<4xi1>
+  %mask = arith.select %cond, %true, %false : vector<4xi1>
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %res : vector<4xf16>
+}

From 86115e7ae2fd3b7e91f453ff6e1ac5310bcf5e67 Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Thu, 10 Jul 2025 12:27:37 +0100
Subject: [PATCH 2/3] Fix impl and address comments

---
 .../AMDGPU/Transforms/MaskedloadToLoad.cpp    | 65 +++++++------------
 .../Dialect/AMDGPU/maskedload-to-load.mlir    | 52 +++++++++++----
 2 files changed, 64 insertions(+), 53 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index b290dc46910e3..23783362454a6 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -52,42 +53,24 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
 }
 
 static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
-                                           vector::MaskedLoadOp maskedOp) {
+                                           vector::MaskedLoadOp maskedOp,
+                                           bool passthru) {
   VectorType vectorType = maskedOp.getVectorType();
   Value load = builder.create<vector::LoadOp>(
       loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
-  Value res = builder.create<arith::SelectOp>(
-      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
-  return res;
+  if (passthru)
+    load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
+                                           load, maskedOp.getPassThru());
+  return load;
 }
 
-/// Check if the given value comes from a:
-///
-/// arith.select %cond, TRUE/FALSE, TRUE/FALSE
-///
-/// i.e the condition is either always true or it's always false.
-///
-/// Returns the condition to use for scf.if (condition) { true } else { false }.
-static FailureOr<Value> matchFullSelect(OpBuilder &b, Value val) {
-  auto selectOp = val.getDefiningOp<arith::SelectOp>();
-  if (!selectOp)
-    return failure();
-  std::optional<int64_t> trueInt = getConstantIntValue(selectOp.getTrueValue());
-  std::optional<int64_t> falseInt =
-      getConstantIntValue(selectOp.getFalseValue());
-  if (!trueInt || !falseInt)
+/// Check if the given value comes from a broadcasted i1 condition.
+static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
     return failure();
-  // getConstantIntValue returns -1 for "true" for bools.
-  if (trueInt.value() == -1 && falseInt.value() == 0)
-    return selectOp.getCondition();
-
-  if (trueInt.value() == 0 && falseInt.value() == -1) {
-    Value cond = selectOp.getCondition();
-    Value one = b.create<arith::ConstantIntOp>(cond.getLoc(), /*value=*/true,
-                                               /*width=*/1);
-    Value inverse = b.create<arith::XOrIOp>(cond.getLoc(), cond, one);
-    return inverse;
-  }
+  if (!isa<VectorType>(broadcastOp.getSourceType()))
+    return broadcastOp.getSource();
   return failure();
 }
 
@@ -109,11 +92,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
-    // so, take the fast path and don't generate a if condition, because we know
-    // doing the oob load is always safe.
-    if (succeeded(matchFullSelect(rewriter, maskedOp.getMask()))) {
-      Value load =
-          createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(), maskedOp);
+    // so, take the fast path and don't generate an if condition, because we
+    // know doing the oob load is always safe.
+    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+                                                 maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
       return success();
     }
@@ -175,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
+                                                /*passthru=*/true);
       rewriter.create<scf::YieldOp>(loc, res);
     };
 
@@ -192,17 +176,17 @@ struct FullMaskedLoadToConditionalLoad
     : OpRewritePattern<vector::MaskedLoadOp> {
   using OpRewritePattern::OpRewritePattern;
 
-public:
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullSelect(rewriter, loadOp.getMask());
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
 
     Value cond = maybeCond.value();
     auto trueBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp);
+      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
+                                                /*passthru=*/false);
       rewriter.create<scf::YieldOp>(loc, res);
     };
     auto falseBuilder = [&](OpBuilder &builder, Location loc) {
@@ -219,10 +203,9 @@ struct FullMaskedStoreToConditionalStore
     : OpRewritePattern<vector::MaskedStoreOp> {
   using OpRewritePattern::OpRewritePattern;
 
-public:
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullSelect(rewriter, storeOp.getMask());
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index d6682ba14eeca..f1d0ad545539a 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -124,18 +124,46 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
 
 // -----
 
-func.func @full_select_maskedload_fatrawbuffer_to_load(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
-  %true = arith.constant dense<true> : vector<4xi1>
-  %false = arith.constant dense<false> : vector<4xi1>
-  %mask = arith.select %cond, %true, %false : vector<4xi1>
-  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
-  return %res : vector<4xf16>
+// CHECK-LABEL: func.func @full_select_maskedload_fatrawbuffer_to_load
+func.func @full_select_maskedload_fatrawbuffer_to_load(%arg0: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
 }
+// CHECK-NOT: vector.maskedload
+// CHECK: vector.load
+// CHECK: arith.select
 
-func.func @full_select_maskedload_to_load(%mem : memref<8x8xf16>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
-  %true = arith.constant dense<true> : vector<4xi1>
-  %false = arith.constant dense<false> : vector<4xi1>
-  %mask = arith.select %cond, %true, %false : vector<4xi1>
-  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
-  return %res : vector<4xf16>
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_mask_maskedstore_to_store
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  vector.maskedstore %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
+  return
 }
+// CHECK-NOT: vector.maskedstore
+// CHECK: scf.if %[[PRED]]
+// CHECK: vector.store

From ede3a936b55025977e3330d9cc8f3958f17aa2cb Mon Sep 17 00:00:00 2001
From: Kunwar Grover
Date: Thu, 10 Jul 2025 15:36:31 +0100
Subject: [PATCH 3/3] address comment

---
 mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index 23783362454a6..60c8660658a95 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -69,9 +69,9 @@ static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (!isa<VectorType>(broadcastOp.getSourceType()))
-    return broadcastOp.getSource();
-  return failure();
+  if (isa<VectorType>(broadcastOp.getSourceType()))
+    return failure();
+  return broadcastOp.getSource();
 }
 
 static constexpr char kMaskedloadNeedsMask[] =