From bf34bb1c5be8cbc20b7880d9e4b148ffe741f909 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 2 May 2025 13:46:58 -0400 Subject: [PATCH] [mlir][SCF][TilingInterface] Add support for divisibility hints to tiling options The tiling helpers in SCF support emitting residual tiles when the iteration space is not evenly divisible by the tile size. For static iteration spaces, the helpers omit the checks for residual tiles when it is obvious they aren't needed; however, they are always pessimistic for dynamic sizes. Rather than requiring the helpers to analyze the IR for divisibility information, which can be expensive, it's better to do such analysis in advance and then tile. This patch adds a tiling option called |divisibilityHint| that allows specifying which dimensions of the iteration space are known to be divisible by their respective tile sizes. Dimensions for which no hint is given are conservatively treated as not divisible. --- .../SCF/Transforms/TileUsingInterface.h | 11 +++ .../SCF/Transforms/TileUsingInterface.cpp | 31 +++--- .../TilingInterface/tile-using-scfforall.mlir | 96 +++++++++++++++++++ .../TestTilingInterfaceTransformOps.cpp | 35 ++++--- .../TestTilingInterfaceTransformOps.td | 6 ++ 5 files changed, 157 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 33a43ce2ee7bb..3ca1bdd0fbf76 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -124,6 +124,17 @@ struct SCFTilingOptions { mappingVector = llvm::to_vector(mapping); return *this; } + + /// Gives hints for whether the tile sizes divide the iteration space evenly. + /// For static sizes, this is trivially verifiable (and the helpers here take + /// advantage of that), however for dynamic sizes we are always forced to be + /// pessimistic. This allows external analysis to check for divisibility and + /// pass on the info to tiling. 
+ SmallVector divisibilityHint = {}; + SCFTilingOptions &setDivisibilityHint(ArrayRef hint) { + divisibilityHint.assign(hint.begin(), hint.end()); + return *this; + } }; /// Transformation information returned after tiling. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 7edf19689d2e1..16edb1d8c6fce 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -248,7 +248,8 @@ static std::tuple, SmallVector> getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef iterationDomain, ArrayRef tileSizes, - ArrayRef numThreads) { + ArrayRef numThreads, + ArrayRef divisibilityHint) { SmallVector offsets, sizes; int materializedLoopNum = 0; @@ -260,8 +261,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, offsetExpr = d0 + d1 * s0; residualTileSizeExpr = s1 - (d0 + d1 * s0); - for (auto [nt, tileSize, loopRange] : - llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { + for (auto [nt, tileSize, loopRange, divHint] : llvm::zip_equal( + numThreads, tileSizes, iterationDomain, divisibilityHint)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. @@ -280,7 +281,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, {loopRange.offset, nt, tileSize, loopRange.size}); OpFoldResult size = tileSize; - if (!isConstantIntValue(residualTileSize, 0)) { + if (!isConstantIntValue(residualTileSize, 0) && !divHint) { OpFoldResult sizeMinusOffsetPerThread = affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, {offset, loopRange.size}); @@ -299,7 +300,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // `nonNegativeTileSize = affine.max(0, tileSize)`. 
// This `max` can be avoided if // `offset + tileSize * (numThreads - 1) < (ub - lb)` - if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { + if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size) && + !divHint) { AffineMap maxMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); size = affine::makeComposedFoldedAffineMax( @@ -311,8 +313,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, } return {offsets, sizes}; } else { - for (auto [tileSize, loopRange] : - llvm::zip_equal(tileSizes, iterationDomain)) { + for (auto [tileSize, loopRange, divHint] : + llvm::zip_equal(tileSizes, iterationDomain, divisibilityHint)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. @@ -325,8 +327,9 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, Value iv = ivs[materializedLoopNum++]; OpFoldResult offset = getAsOpFoldResult(iv); offsets.push_back(offset); - OpFoldResult size = - getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); + OpFoldResult size = divHint ? tileSize + : getBoundedTileSize(rewriter, loc, loopRange, + offset, tileSize); sizes.push_back(size); } return {offsets, sizes}; @@ -950,6 +953,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, std::tie(tileSizes, numThreads) = getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); + // 2a. Fit the divisibility hints to the domain rank: pad missing entries + // with `false` and drop extras (an over-long hint must not underflow). + SmallVector divisibilityHint = options.divisibilityHint; + divisibilityHint.resize(iterationDomain.size(), false); + // Check if it is safe to tile. This is hold over from previous iterations // of tile to for-all. Consider dropping it. if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { @@ -982,8 +990,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, -> LogicalResult { // 4a. Compute the `offsets` and `sizes` to use for tiling. 
SmallVector offsets, sizes; - std::tie(offsets, sizes) = getTileOffsetAndSizes( - rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); + std::tie(offsets, sizes) = + getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain, tileSizes, + numThreads, divisibilityHint); // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir index 745a82fc0da75..558ca798fffc2 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir @@ -349,3 +349,99 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @check_scalar_memref_operation // CHECK-NOT: scf.for // CHECK: linalg.generic + +// ----- + +func.func @simple_matmul_assume_divisible_n(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %matmul [10, 20] + divisibility_hint = [false, true] mapping = [#gpu.block, #gpu.block] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)> +// CHECK: func.func @simple_matmul_assume_divisible_n( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index 
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]]) +// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1] +// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], 20] [1, 1] +// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], 20] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : tensor, tensor +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], 20] [1, 1] +// CHECK: mapping = [#gpu.block, #gpu.block] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @simple_matmul_extend_divisibility(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %matmul [10, 20] + divisibility_hint = [true] mapping = [#gpu.block, #gpu.block] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)> +// CHECK: func.func 
@simple_matmul_extend_divisibility( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP0]](%[[IV1]])[%[[N]]] +// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [10, %[[K]]] [1, 1] +// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1] +// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [10, %[[TS_X]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : tensor<10x?xf32>, tensor +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [10, %[[TS_X]]] [1, 1] +// CHECK: mapping = [#gpu.block, #gpu.block] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 45d6ae3820159..41446ba176a3b 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" +#include 
"llvm/ADT/SmallVectorExtras.h" #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.h.inc" @@ -54,12 +55,11 @@ static llvm::SmallDenseSet collectTiledAndFusedOps(Operation *op) { /// Apply a tile and fuse transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. template -static LogicalResult -applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, - Range &&payloadOps, unsigned numLoops, - ArrayRef tileSizes, - ArrayRef interchange, bool useForall, - TransformResults &transformResults) { +static LogicalResult applyTileAndFuseToAll( + RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + unsigned numLoops, ArrayRef tileSizes, + ArrayRef interchange, ArrayRef divisibilityHint, + bool useForall, TransformResults &transformResults) { SmallVector tiledOps; SmallVector> loopOps(numLoops); @@ -85,6 +85,7 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, if (useForall) { tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); } + tilingOptions.setDivisibilityHint(divisibilityHint); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.setTilingOptions(tilingOptions); @@ -151,13 +152,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, SmallVector tileInterchange = extractFromIntegerArrayAttr(getTileInterchange()); + SmallVector divisibilityHint( + getDivisibilityHint().getAsValueRange()); + SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); LogicalResult result = applyTileAndFuseToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, - tileInterchange, getUseForall(), transformResults); + tileInterchange, divisibilityHint, getUseForall(), transformResults); return failed(result) ? 
DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } @@ -237,7 +241,8 @@ template static LogicalResult applyTileToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, ArrayRef tileSizes, - ArrayRef interchange, std::optional mapping, + ArrayRef interchange, ArrayRef divisibilityHint, + std::optional mapping, TransformResults &transformResults) { SmallVector tiledOps; SmallVector loopOps; @@ -251,6 +256,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp, if (mapping) { tilingOptions.setMapping(mapping.value().getValue()); } + tilingOptions.setDivisibilityHint(divisibilityHint); tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); rewriter.setInsertionPoint(target); @@ -287,9 +293,12 @@ transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); - LogicalResult result = - applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizesOfr, interchange, getMapping(), transformResults); + SmallVector divisibilityHint( + getDivisibilityHint().getAsValueRange()); + + LogicalResult result = applyTileToAll( + rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizesOfr, + interchange, divisibilityHint, getMapping(), transformResults); return failed(result) ? 
DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } @@ -363,11 +372,15 @@ transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, SmallVector tileInterchange = extractFromIntegerArrayAttr(getInterchange()); + SmallVector divisibilityHint( + getDivisibilityHint().getAsValueRange()); + scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); + tilingOptions = tilingOptions.setDivisibilityHint(divisibilityHint); tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 98f7145c99cb1..e7d8732808d78 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -38,12 +38,14 @@ def TestFuseAndYieldOp : Op:$tile_sizes, DefaultValuedAttr:$tile_interchange, + DefaultValuedOptionalAttr:$divisibility_hint, DefaultValuedAttr:$use_forall); let results = (outs TransformHandleTypeInterface:$transfomed, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` $tile_interchange^)? + (`divisibility_hint` `=` $divisibility_hint^)? (`use_forall` $use_forall^)? attr-dict `:` functional-type(operands, results) }]; @@ -91,12 +93,14 @@ def TestTileUsingForallOp : Op:$tile_sizes, DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$divisibility_hint, OptionalAttr:$mapping); let results = (outs TransformHandleTypeInterface:$tiled_op, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` `=` $interchange^)? 
+ (`divisibility_hint` `=` $divisibility_hint^)? (`mapping` `=` $mapping^)? attr-dict `:` functional-type(operands, results) }]; @@ -114,12 +118,14 @@ def TestFuseUsingForallOp : Op:$tile_sizes, DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$divisibility_hint, OptionalAttr:$mapping); let results = (outs TransformHandleTypeInterface:$tiled_ops, Variadic:$loops); let assemblyFormat = [{ $root_op ($tile_sizes^)? (`interchange` $interchange^)? + (`divisibility_hint` `=` $divisibility_hint^)? (`mapping` `=` $mapping^)? attr-dict `:` functional-type(operands, results) }];