From 586839035dfaaf45d66ad6b1184f94465c10906f Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 2 Jul 2025 17:29:08 +0000 Subject: [PATCH 01/13] working but bug in dead result --- .../Vector/Transforms/VectorDistribute.cpp | 66 ++++++++++++++----- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index af90ed8f5deaf..28c957bf61921 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/SetVector.h" @@ -1777,24 +1778,42 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (llvm::is_contained(distTypes, Type{})) return failure(); + llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n"; + + SmallVector yieldedValuesFromWarpOp; + // All init args of the forOp are yielded from the original warp op. + for (Value initArg : forOp.getInitArgs()) { + yieldedValuesFromWarpOp.push_back(initArg); + // find distributed type for the init arg. + Type distType = initArg.getType(); + if (auto vecType = dyn_cast(distType)) { + AffineMap map = distributionMapFn(initArg); + distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + } + distTypes.push_back(distType); + } + // All escaping values are yielded from the original warp op. + yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(), + escapingValues.begin(), + escapingValues.end()); + SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, escapingValues.getArrayRef(), distTypes, - newRetIndices); + rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices); yield = cast( newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); SmallVector newOperands; SmallVector resultIdx; - // Collect all the outputs coming from the forOp. + // Collect the new init args coming from the new warp op. + for (size_t i = 0; i < forOp.getInitArgs().size(); ++i) + newOperands.push_back(newWarpOp.getResult(newRetIndices[i])); for (OpOperand &yieldOperand : yield->getOpOperands()) { if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; - auto forResult = cast(yieldOperand.get()); - newOperands.push_back( - newWarpOp.getResult(yieldOperand.getOperandNumber())); + OpResult forResult = cast(yieldOperand.get()); + resultIdx.push_back(forResult.getResultNumber()); yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); - resultIdx.push_back(yieldOperand.getOperandNumber()); } OpBuilder::InsertionGuard g(rewriter); @@ -1812,8 +1831,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { SmallVector warpInputType(forOp.getResultTypes().begin(), forOp.getResultTypes().end()); llvm::SmallDenseMap argIndexMapping; - for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { - warpInput.push_back(newWarpOp.getResult(retIdx)); + for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) { + warpInput.push_back(newWarpOp.getResult(i)); argIndexMapping[escapingValues[i]] = warpInputType.size(); warpInputType.push_back(inputTypes[i]); } @@ -1826,24 +1845,37 @@ struct WarpOpScfForOp : public WarpDistributionPattern { for (Value args : innerWarp.getBody()->getArguments()) { argMapping.push_back(args); } - argMapping.resize(forOp.getBody()->getNumArguments()); + auto forOpCopy = cast(rewriter.clone(*forOp.getOperation())); + argMapping.resize(forOpCopy.getBody()->getNumArguments()); SmallVector yieldOperands; - for (Value operand : forOp.getBody()->getTerminator()->getOperands()) + for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands()) yieldOperands.push_back(operand); - rewriter.eraseOp(forOp.getBody()->getTerminator()); - rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); + + rewriter.eraseOp(forOpCopy.getBody()->getTerminator()); + rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping); rewriter.setInsertionPointToEnd(innerWarp.getBody()); rewriter.create(innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); if (!innerWarp.getResults().empty()) - rewriter.create(forOp.getLoc(), innerWarp.getResults()); - rewriter.eraseOp(forOp); + rewriter.create(forOpCopy.getLoc(), innerWarp.getResults()); + // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs()); + // llvm::outs() << "\n"; + // llvm::errs() << "erasing for op\n"; + + rewriter.eraseOp(forOpCopy); // Replace the warpOp result coming from the original ForOp. + // print resultIdx for debugging. + llvm::errs() << "resultIdx: "; + for (auto idx : resultIdx) + llvm::errs() << idx << " "; + llvm::errs() << "\n"; for (const auto &res : llvm::enumerate(resultIdx)) { rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), newForOp.getResult(res.index())); - newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); + // newForOp->setOperand(res.index() + 3, + // newWarpOp.getResult(res.value())); } + rewriter.eraseOp(forOp); newForOp.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); @@ -1852,6 +1884,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { operand.set(innerWarp.getBodyRegion().getArgument(it->second)); } }); + newForOp->getParentOp()->print(llvm::outs()); + llvm::outs() << "\n"; // Finally, hoist out any now uniform code from the inner warp op. mlir::vector::moveScalarUniformCode(innerWarp); From 4c363175e0c5a0d6cddfe7ac3532051f76d88039 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 2 Jul 2025 17:59:39 +0000 Subject: [PATCH 02/13] working version --- .../Vector/Transforms/VectorDistribute.cpp | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 28c957bf61921..dae62d2cecc04 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1796,26 +1796,34 @@ struct WarpOpScfForOp : public WarpDistributionPattern { yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(), escapingValues.begin(), escapingValues.end()); - - SmallVector newRetIndices; - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, yieldedValuesFromWarpOp, distTypes, newRetIndices); - yield = cast( - newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - - SmallVector newOperands; + // record result mapping. SmallVector resultIdx; - // Collect the new init args coming from the new warp op. - for (size_t i = 0; i < forOp.getInitArgs().size(); ++i) - newOperands.push_back(newWarpOp.getResult(newRetIndices[i])); for (OpOperand &yieldOperand : yield->getOpOperands()) { if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; OpResult forResult = cast(yieldOperand.get()); resultIdx.push_back(forResult.getResultNumber()); - yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); + // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); } + // SmallVector newRetIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( + rewriter, warpOp, yieldedValuesFromWarpOp, distTypes); + yield = cast( + newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + + SmallVector newOperands; + // Collect the new init args coming from the new warp op. + for (size_t i = 0; i < forOp.getInitArgs().size(); ++i) + newOperands.push_back(newWarpOp.getResult(i)); + // for (OpOperand &yieldOperand : yield->getOpOperands()) { + // if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) + // continue; + // OpResult forResult = cast(yieldOperand.get()); + // resultIdx.push_back(forResult.getResultNumber()); + // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); + // } + OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); @@ -1831,7 +1839,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { SmallVector warpInputType(forOp.getResultTypes().begin(), forOp.getResultTypes().end()); llvm::SmallDenseMap argIndexMapping; - for (size_t i = forOp.getInitArgs().size(); i < newRetIndices.size(); ++i) { + for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults(); + ++i) { warpInput.push_back(newWarpOp.getResult(i)); argIndexMapping[escapingValues[i]] = warpInputType.size(); warpInputType.push_back(inputTypes[i]); @@ -1870,12 +1879,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern { llvm::errs() << idx << " "; llvm::errs() << "\n"; for (const auto &res : llvm::enumerate(resultIdx)) { - rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), - newForOp.getResult(res.index())); + rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()), + newForOp.getResult(res.index()), newForOp); // newForOp->setOperand(res.index() + 3, // newWarpOp.getResult(res.value())); } rewriter.eraseOp(forOp); + rewriter.eraseOp(warpOp); newForOp.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); From 3595f1758ad2c71f2f265253cf0a66ec3bdc94d2 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 2 Jul 2025 22:10:56 +0000 Subject: [PATCH 03/13] working version refined --- .../Vector/Transforms/VectorDistribute.cpp | 75 +++++++++++++------ 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index dae62d2cecc04..52a55d104c0bd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -1778,37 +1779,53 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (llvm::is_contained(distTypes, Type{})) return failure(); - llvm::errs() << "escpaing values size: " << escapingValues.size() << "\n"; + // record result mapping. + SmallVector resultIdx; + llvm::SmallDenseMap forResultToWarpResultMapping; + for (OpOperand &yieldOperand : yield->getOpOperands()) { + if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) + continue; + OpResult forResult = cast(yieldOperand.get()); + resultIdx.push_back(forResult.getResultNumber()); + forResultToWarpResultMapping[forResult.getResultNumber()] = + yieldOperand.getOperandNumber(); + // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); + } + + // llvm::errs() << "escpaing values size: " << escapingValues.size() << + // "\n"; SmallVector yieldedValuesFromWarpOp; + SmallVector yieldedTypesFromWarpOp; // All init args of the forOp are yielded from the original warp op. - for (Value initArg : forOp.getInitArgs()) { + for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { yieldedValuesFromWarpOp.push_back(initArg); // find distributed type for the init arg. Type distType = initArg.getType(); if (auto vecType = dyn_cast(distType)) { - AffineMap map = distributionMapFn(initArg); - distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + if (forResultToWarpResultMapping.contains(i)) { + // If the init arg is yielded from the warp op, we need to compute the + // distributed type. + distType = + warpOp.getResult(forResultToWarpResultMapping[i]).getType(); + } else { + AffineMap map = distributionMapFn(initArg); + distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + } } - distTypes.push_back(distType); + // llvm::errs() << "distributed type: " << distType << "\n"; + yieldedTypesFromWarpOp.push_back(distType); } // All escaping values are yielded from the original warp op. yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(), escapingValues.begin(), escapingValues.end()); - // record result mapping. - SmallVector resultIdx; - for (OpOperand &yieldOperand : yield->getOpOperands()) { - if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) - continue; - OpResult forResult = cast(yieldOperand.get()); - resultIdx.push_back(forResult.getResultNumber()); - // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); - } + yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(), + distTypes.begin(), distTypes.end()); // SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldedValuesFromWarpOp, distTypes); + rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp); yield = cast( newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); @@ -1839,15 +1856,27 @@ struct WarpOpScfForOp : public WarpDistributionPattern { SmallVector warpInputType(forOp.getResultTypes().begin(), forOp.getResultTypes().end()); llvm::SmallDenseMap argIndexMapping; + // llvm::errs() << "setting arg index mapping\n"; for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults(); ++i) { warpInput.push_back(newWarpOp.getResult(i)); - argIndexMapping[escapingValues[i]] = warpInputType.size(); - warpInputType.push_back(inputTypes[i]); + argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] = + warpInputType.size(); + warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]); } + // for (auto [i, r] : llvm::enumerate( + // newWarpOp.getResults().drop_front(forOp.getInitArgs().size()))) + // { + // warpInput.push_back(r); + // argIndexMapping[escapingValues[i]] = warpInputType.size(); + // warpInputType.push_back(inputTypes[i]); + // } + // llvm::errs() << "go here\n"; auto innerWarp = rewriter.create( newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), newWarpOp.getWarpSize(), warpInput, warpInputType); + // newForOp->getParentOp()->print(llvm::outs()); + // llvm::outs() << "\n"; SmallVector argMapping; argMapping.push_back(newForOp.getInductionVar()); @@ -1874,10 +1903,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern { rewriter.eraseOp(forOpCopy); // Replace the warpOp result coming from the original ForOp. // print resultIdx for debugging. - llvm::errs() << "resultIdx: "; - for (auto idx : resultIdx) - llvm::errs() << idx << " "; - llvm::errs() << "\n"; + // llvm::errs() << "resultIdx: "; + // for (auto idx : resultIdx) + // llvm::errs() << idx << " "; + // llvm::errs() << "\n"; for (const auto &res : llvm::enumerate(resultIdx)) { rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()), newForOp.getResult(res.index()), newForOp); @@ -1894,8 +1923,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { operand.set(innerWarp.getBodyRegion().getArgument(it->second)); } }); - newForOp->getParentOp()->print(llvm::outs()); - llvm::outs() << "\n"; + // newForOp->getParentOp()->print(llvm::outs()); + // llvm::outs() << "\n"; // Finally, hoist out any now uniform code from the inner warp op. mlir::vector::moveScalarUniformCode(innerWarp); From ba94ee21098465ea466c79acaaf01953b34cfc70 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 8 Jul 2025 00:21:19 +0000 Subject: [PATCH 04/13] working failing case now --- .../Vector/Transforms/VectorDistribute.cpp | 81 +++++++++++++------ 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 52a55d104c0bd..adfed18a625b3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -19,8 +19,10 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -1779,22 +1781,35 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (llvm::is_contained(distTypes, Type{})) return failure(); + SmallVector nonForYieldedValues; + // SmallVector nonForYieldedTypes; + SmallVector nonForResultIndices; + // record result mapping. - SmallVector resultIdx; - llvm::SmallDenseMap forResultToWarpResultMapping; + DenseMap forResultMapping; + DenseMap warpResultMapping; + // llvm::SmallDenseMap forResultToWarpResultMapping; for (OpOperand &yieldOperand : yield->getOpOperands()) { - if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) + if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { + nonForYieldedValues.push_back(yieldOperand.get()); + // nonForYieldedTypes.push_back( + // warpOp.getResult(yieldOperand.getOperandNumber()).getType()); + nonForResultIndices.push_back(yieldOperand.getOperandNumber()); continue; + } OpResult forResult = cast(yieldOperand.get()); - resultIdx.push_back(forResult.getResultNumber()); - forResultToWarpResultMapping[forResult.getResultNumber()] = - yieldOperand.getOperandNumber(); + forResultMapping[yieldOperand.getOperandNumber()] = + forResult.getResultNumber(); + // forResultToWarpResultMapping[forResult.getResultNumber()] = + // yieldOperand.getOperandNumber(); // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); } + // llvm::errs() << "non for yielded values size: " + // << nonForYieldedValues.size() << "\n"; + // llvm::errs() << "escpaing values size: " << escapingValues.size() << // "\n"; - SmallVector yieldedValuesFromWarpOp; SmallVector yieldedTypesFromWarpOp; // All init args of the forOp are yielded from the original warp op. @@ -1803,15 +1818,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // find distributed type for the init arg. Type distType = initArg.getType(); if (auto vecType = dyn_cast(distType)) { - if (forResultToWarpResultMapping.contains(i)) { - // If the init arg is yielded from the warp op, we need to compute the - // distributed type. - distType = - warpOp.getResult(forResultToWarpResultMapping[i]).getType(); - } else { - AffineMap map = distributionMapFn(initArg); - distType = getDistributedType(vecType, map, warpOp.getWarpSize()); - } + // if (forResultToWarpResultMapping.contains(i)) { + // // If the init arg is yielded from the warp op, we need to compute + // the + // // distributed type. + // distType = + // warpOp.getResult(forResultToWarpResultMapping[i]).getType(); + // } else { + AffineMap map = distributionMapFn(initArg); + distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + // } } // llvm::errs() << "distributed type: " << distType << "\n"; yieldedTypesFromWarpOp.push_back(distType); @@ -1823,12 +1839,23 @@ struct WarpOpScfForOp : public WarpDistributionPattern { yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(), distTypes.begin(), distTypes.end()); + for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) { + warpResultMapping[nonForResultIndices[i]] = + yieldedValuesFromWarpOp.size(); + yieldedValuesFromWarpOp.push_back(v); + yieldedTypesFromWarpOp.push_back( + warpOp.getResult(nonForResultIndices[i]).getType()); + } + // SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp); yield = cast( newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + // newWarpOp->print(llvm::outs()); + // llvm::outs() << "\n"; + SmallVector newOperands; // Collect the new init args coming from the new warp op. for (size_t i = 0; i < forOp.getInitArgs().size(); ++i) @@ -1857,12 +1884,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern { forOp.getResultTypes().end()); llvm::SmallDenseMap argIndexMapping; // llvm::errs() << "setting arg index mapping\n"; - for (size_t i = forOp.getInitArgs().size(); i < newWarpOp->getNumResults(); - ++i) { + unsigned escapingValuesStartIdx = forOp.getInitArgs().size(); + for (size_t i = escapingValuesStartIdx; + i < escapingValuesStartIdx + escapingValues.size(); ++i) { warpInput.push_back(newWarpOp.getResult(i)); - argIndexMapping[escapingValues[i - forOp.getInitArgs().size()]] = + argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = warpInputType.size(); - warpInputType.push_back(inputTypes[i - forOp.getInitArgs().size()]); + warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]); } // for (auto [i, r] : llvm::enumerate( // newWarpOp.getResults().drop_front(forOp.getInitArgs().size()))) @@ -1907,9 +1935,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // for (auto idx : resultIdx) // llvm::errs() << idx << " "; // llvm::errs() << "\n"; - for (const auto &res : llvm::enumerate(resultIdx)) { - rewriter.replaceAllUsesExcept(warpOp.getResult(res.value()), - newForOp.getResult(res.index()), newForOp); + for (auto [origIdx, newIdx] : forResultMapping) { + rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + newForOp.getResult(newIdx), newForOp); + // newForOp->setOperand(res.index() + 3, + // newWarpOp.getResult(res.value())); + } + + for (auto [origIdx, newIdx] : warpResultMapping) { + rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), + newWarpOp.getResult(newIdx)); // newForOp->setOperand(res.index() + 3, // newWarpOp.getResult(res.value())); } From 28ef9c9400695acb954c34bc48c9a43b710fb92b Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Tue, 8 Jul 2025 23:27:46 +0000 Subject: [PATCH 05/13] add comments and tests --- .../Vector/Transforms/VectorDistribute.cpp | 217 ++++++++---------- .../Vector/vector-warp-distribute.mlir | 79 +++++++ 2 files changed, 172 insertions(+), 124 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index adfed18a625b3..b49c2063b075d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1749,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern { : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast( + auto newWarpOpYield = cast( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); // Only pick up forOp if it is the last op in the region. - Operation *lastNode = yield->getPrevNode(); + Operation *lastNode = newWarpOpYield->getPrevNode(); auto forOp = dyn_cast_or_null(lastNode); if (!forOp) return failure(); // Collect Values that come from the warp op but are outside the forOp. - // Those Value needs to be returned by the original warpOp and passed to - // the new op. + // Those Value needs to be returned by the new warp op. llvm::SmallSetVector escapingValues; - SmallVector inputTypes; - SmallVector distTypes; + SmallVector escapingValueInputTypes; + SmallVector escapingValuedistTypes; mlir::visitUsedValuesDefinedAbove( forOp.getBodyRegion(), [&](OpOperand *operand) { Operation *parent = operand->get().getParentRegion()->getParentOp(); @@ -1773,183 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern { AffineMap map = distributionMapFn(operand->get()); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } - inputTypes.push_back(operand->get().getType()); - distTypes.push_back(distType); + escapingValueInputTypes.push_back(operand->get().getType()); + escapingValuedistTypes.push_back(distType); } }); - if (llvm::is_contained(distTypes, Type{})) + if (llvm::is_contained(escapingValuedistTypes, Type{})) return failure(); - + // Warp op can yield two types of values: + // 1. Values that are not results of the forOp: + // These values must also be yielded by the new warp op. Also, we need to + // record the index mapping for these values to replace them later. + // 2. Values that are results of the forOp: + // In this case, we record the index mapping between the warp op result + // index and matching forOp result index. SmallVector nonForYieldedValues; - // SmallVector nonForYieldedTypes; SmallVector nonForResultIndices; - - // record result mapping. DenseMap forResultMapping; - DenseMap warpResultMapping; - // llvm::SmallDenseMap forResultToWarpResultMapping; - for (OpOperand &yieldOperand : yield->getOpOperands()) { + for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) { + // Yielded value is not a result of the forOp. if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { nonForYieldedValues.push_back(yieldOperand.get()); - // nonForYieldedTypes.push_back( - // warpOp.getResult(yieldOperand.getOperandNumber()).getType()); nonForResultIndices.push_back(yieldOperand.getOperandNumber()); continue; } OpResult forResult = cast(yieldOperand.get()); forResultMapping[yieldOperand.getOperandNumber()] = forResult.getResultNumber(); - // forResultToWarpResultMapping[forResult.getResultNumber()] = - // yieldOperand.getOperandNumber(); - // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); } - // llvm::errs() << "non for yielded values size: " - // << nonForYieldedValues.size() << "\n"; - - // llvm::errs() << "escpaing values size: " << escapingValues.size() << - // "\n"; - SmallVector yieldedValuesFromWarpOp; - SmallVector yieldedTypesFromWarpOp; - // All init args of the forOp are yielded from the original warp op. + // Newly created warp op will yield values in following order: + // 1. All init args of the forOp. + // 2. All escaping values. + // 3. All non-for yielded values. + SmallVector newWarpOpYieldValues; + SmallVector newWarpOpDistTypes; for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { - yieldedValuesFromWarpOp.push_back(initArg); - // find distributed type for the init arg. + newWarpOpYieldValues.push_back(initArg); + // Compute the distributed type for this init arg. Type distType = initArg.getType(); if (auto vecType = dyn_cast(distType)) { - // if (forResultToWarpResultMapping.contains(i)) { - // // If the init arg is yielded from the warp op, we need to compute - // the - // // distributed type. - // distType = - // warpOp.getResult(forResultToWarpResultMapping[i]).getType(); - // } else { AffineMap map = distributionMapFn(initArg); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); - // } } - // llvm::errs() << "distributed type: " << distType << "\n"; - yieldedTypesFromWarpOp.push_back(distType); + newWarpOpDistTypes.push_back(distType); } - // All escaping values are yielded from the original warp op. - yieldedValuesFromWarpOp.insert(yieldedValuesFromWarpOp.end(), - escapingValues.begin(), - escapingValues.end()); - yieldedTypesFromWarpOp.insert(yieldedTypesFromWarpOp.end(), - distTypes.begin(), distTypes.end()); - + // Insert escaping values and their distributed types. + newWarpOpYieldValues.insert(newWarpOpYieldValues.end(), + escapingValues.begin(), escapingValues.end()); + newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), + escapingValuedistTypes.begin(), + escapingValuedistTypes.end()); + // Next, we insert all non-for yielded values and their distributed types. + // We also create a mapping between the non-for yielded value index and the + // corresponding new warp op yield value index (needed to update users + // later). + DenseMap warpResultMapping; for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) { - warpResultMapping[nonForResultIndices[i]] = - yieldedValuesFromWarpOp.size(); - yieldedValuesFromWarpOp.push_back(v); - yieldedTypesFromWarpOp.push_back( + warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size(); + newWarpOpYieldValues.push_back(v); + newWarpOpDistTypes.push_back( warpOp.getResult(nonForResultIndices[i]).getType()); } - - // SmallVector newRetIndices; + // Create the new warp op with the updated yield values and types. WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldedValuesFromWarpOp, yieldedTypesFromWarpOp); - yield = cast( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + newWarpOpYield = cast( newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - // newWarpOp->print(llvm::outs()); - // llvm::outs() << "\n"; - - SmallVector newOperands; - // Collect the new init args coming from the new warp op. - for (size_t i = 0; i < forOp.getInitArgs().size(); ++i) - newOperands.push_back(newWarpOp.getResult(i)); - // for (OpOperand &yieldOperand : yield->getOpOperands()) { - // if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) - // continue; - // OpResult forResult = cast(yieldOperand.get()); - // resultIdx.push_back(forResult.getResultNumber()); - // yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); - // } + // Next, we create a new for op with the init args yielded by the new + // warp op. + unsigned escapingValuesStartIdx = + forOp.getInitArgs().size(); // ForOp init args are positioned before + // escaping values in the new warp op. + SmallVector newForOpOperands; + for (size_t i = 0; i < escapingValuesStartIdx; ++i) + newForOpOperands.push_back(newWarpOp.getResult(i)); + // Create a new for op outside the new warp op region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); - - // Create a new for op outside the region with a WarpExecuteOnLane0Op - // region inside. auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newOperands); + forOp.getStep(), newForOpOperands); + // Next, we insert a new warp op (called inner warp op) inside the + // newly created for op. This warp op will contain all ops that were + // contained within the original for op body. rewriter.setInsertionPointToStart(newForOp.getBody()); - SmallVector warpInput(newForOp.getRegionIterArgs().begin(), - newForOp.getRegionIterArgs().end()); - SmallVector warpInputType(forOp.getResultTypes().begin(), - forOp.getResultTypes().end()); + SmallVector innerWarpInput(newForOp.getRegionIterArgs().begin(), + newForOp.getRegionIterArgs().end()); + SmallVector innerWarpInputType(forOp.getResultTypes().begin(), + forOp.getResultTypes().end()); + // Escaping values are forwarded to the inner warp op as its (additional) + // arguments. We keep track of the mapping between these values and their + // argument index in the inner warp op (to replcace uses later). llvm::SmallDenseMap argIndexMapping; - // llvm::errs() << "setting arg index mapping\n"; - unsigned escapingValuesStartIdx = forOp.getInitArgs().size(); for (size_t i = escapingValuesStartIdx; i < escapingValuesStartIdx + escapingValues.size(); ++i) { - warpInput.push_back(newWarpOp.getResult(i)); + innerWarpInput.push_back(newWarpOp.getResult(i)); argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = - warpInputType.size(); - warpInputType.push_back(inputTypes[i - escapingValuesStartIdx]); + innerWarpInputType.size(); + innerWarpInputType.push_back( + escapingValueInputTypes[i - escapingValuesStartIdx]); } - // for (auto [i, r] : llvm::enumerate( - // newWarpOp.getResults().drop_front(forOp.getInitArgs().size()))) - // { - // warpInput.push_back(r); - // argIndexMapping[escapingValues[i]] = warpInputType.size(); - // warpInputType.push_back(inputTypes[i]); - // } - // llvm::errs() << "go here\n"; + // Create the inner warp op with the new input values and types. auto innerWarp = rewriter.create( newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), - newWarpOp.getWarpSize(), warpInput, warpInputType); - // newForOp->getParentOp()->print(llvm::outs()); - // llvm::outs() << "\n"; + newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType); + // Inline the for op body into the inner warp op body. SmallVector argMapping; argMapping.push_back(newForOp.getInductionVar()); - for (Value args : innerWarp.getBody()->getArguments()) { + for (Value args : innerWarp.getBody()->getArguments()) argMapping.push_back(args); - } - auto forOpCopy = cast(rewriter.clone(*forOp.getOperation())); - argMapping.resize(forOpCopy.getBody()->getNumArguments()); + + argMapping.resize(forOp.getBody()->getNumArguments()); SmallVector yieldOperands; - for (Value operand : forOpCopy.getBody()->getTerminator()->getOperands()) + for (Value operand : forOp.getBody()->getTerminator()->getOperands()) yieldOperands.push_back(operand); - rewriter.eraseOp(forOpCopy.getBody()->getTerminator()); - rewriter.mergeBlocks(forOpCopy.getBody(), innerWarp.getBody(), argMapping); + rewriter.eraseOp(forOp.getBody()->getTerminator()); + rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); + + // Insert a gpu yieldOp at the end of the inner warp op body that yields + // original forOp results. rewriter.setInsertionPointToEnd(innerWarp.getBody()); rewriter.create(innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); + // Insert a scf.yield op at the end of the new for op body that yields + // the inner warp op results. if (!innerWarp.getResults().empty()) - rewriter.create(forOpCopy.getLoc(), innerWarp.getResults()); - // forOpCopy->getParentOp()->getParentOp()->print(llvm::outs()); - // llvm::outs() << "\n"; - // llvm::errs() << "erasing for op\n"; - - rewriter.eraseOp(forOpCopy); - // Replace the warpOp result coming from the original ForOp. - // print resultIdx for debugging. - // llvm::errs() << "resultIdx: "; - // for (auto idx : resultIdx) - // llvm::errs() << idx << " "; - // llvm::errs() << "\n"; - for (auto [origIdx, newIdx] : forResultMapping) { + rewriter.create(forOp.getLoc(), innerWarp.getResults()); + + // Update the users of original warp op results that were coming from the + // original forOp to the corresponding new forOp result. + for (auto [origIdx, newIdx] : forResultMapping) rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), newForOp.getResult(newIdx), newForOp); - // newForOp->setOperand(res.index() + 3, - // newWarpOp.getResult(res.value())); - } - - for (auto [origIdx, newIdx] : warpResultMapping) { + // Similarly, update any users of the warp op results that were not + // results of the forOp. + for (auto [origIdx, newIdx] : warpResultMapping) rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), newWarpOp.getResult(newIdx)); - // newForOp->setOperand(res.index() + 3, - // newWarpOp.getResult(res.value())); - } + // Remove the original warp op and for op, they should not have any uses + // at this point. rewriter.eraseOp(forOp); rewriter.eraseOp(warpOp); + // Update any users of escaping values that were forwarded to the + // inner warp op. These values are now arguments of the inner warp op. newForOp.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); @@ -1958,8 +1929,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern { operand.set(innerWarp.getBodyRegion().getArgument(it->second)); } }); - // newForOp->getParentOp()->print(llvm::outs()); - // llvm::outs() << "\n"; // Finally, hoist out any now uniform code from the inner warp op. mlir::vector::moveScalarUniformCode(innerWarp); diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 7cfbcdf101d11..3982783c764df 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref, %arg2 return } +// ----- +// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result( +// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) { +// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) { +// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) { +// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32> +// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32> +// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32> +// CHECK-PROP: } +// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> () +func.func @warp_scf_for_unused_for_result(%arg0: index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) { + %ini = "some_def"() : () -> (vector<128xf32>) + %ini1 = "some_def"() : () -> (vector<128xf32>) + %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) { + %add = arith.addi %arg3, %c1 : index + %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>) + %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc, %1 : vector<128xf32>, vector<128xf32> + } + gpu.yield %3#0 : vector<128xf32> + } + "some_use"(%0) : (vector<4xf32>) -> () + return +} + +// ----- +// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results( +// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { +// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32> +// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32> +// CHECK-PROP-NEXT: } +// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { +// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : +// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { +// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>): +// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32> +// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32> +// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32> +// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32> +// CHECK-PROP-NEXT: } +// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32> +// CHECK-PROP-NEXT: } +// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> () +// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> () +// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> () +func.func @warp_scf_for_swapped_for_results(%arg0: index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) { + %ini1 = "some_def"() : () -> (vector<256xf32>) + %ini2 = "some_def"() : () -> (vector<128xf32>) + %ini3 = "some_def"() : () -> (vector<128xf32>) + %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) { + %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>) + %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>) + %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32> + } + gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> + } + "some_use_1"(%0#0) : (vector<4xf32>) -> () + "some_use_2"(%0#1) : (vector<4xf32>) -> () + "some_use_3"(%0#2) : (vector<8xf32>) -> () + return +} + // ----- // CHECK-PROP-LABEL: func @vector_reduction( From 537ca0e285c9a220a0dc0d53e24f51c86e81d5e7 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 9 Jul 2025 02:46:33 +0000 Subject: [PATCH 06/13] add missing logic --- .../Transforms/XeGPUSubgroupDistribute.cpp | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index c072557c2bd22..ef257307de569 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" namespace mlir { namespace xegpu { @@ -876,15 +877,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Step 3: Apply subgroup to workitem distribution patterns. RewritePatternSet patterns(&getContext()); xegpu::populateXeGPUSubgroupDistributePatterns(patterns); - // TODO: distributionFn and shuffleFn are not used at this point. + // distributionFn is used by vector distribution patterns to determine the + // distributed vector type for a given vector value. In XeGPU subgroup + // distribution context, we compute this based on lane layout. auto distributionFn = [](Value val) { VectorType vecType = dyn_cast(val.getType()); int64_t vecRank = vecType ? vecType.getRank() : 0; - OpBuilder builder(val.getContext()); if (vecRank == 0) return AffineMap::get(val.getContext()); - return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); + // Get the layout of the vector type. + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); + // If no layout is specified, assume the inner most dimension is distributed + // for now. + if (!layout) + return AffineMap::getMultiDimMapWithTargets( + vecRank, {static_cast(vecRank - 1)}, val.getContext()); + SmallVector distributedDims; + // Get the distributed dimensions based on the layout. + ArrayRef laneLayout = layout.getLaneLayout().asArrayRef(); + for (unsigned i = 0; i < laneLayout.size(); ++i) { + if (laneLayout[i] > 1) + distributedDims.push_back(i); + } + return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, + val.getContext()); }; + // TODO: shuffleFn is not used. auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, int64_t warpSz) { return Value(); }; vector::populatePropagateWarpVectorDistributionPatterns( From 164e9d619880ae24388234edcd9297b66c31d209 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Wed, 9 Jul 2025 17:57:13 +0000 Subject: [PATCH 07/13] address comments --- .../Vector/Transforms/VectorDistribute.cpp | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 3ce134fe5f3ce..7d3d6b98666a1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1751,13 +1751,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern { PatternRewriter &rewriter) const override { auto newWarpOpYield = cast( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - // Only pick up forOp if it is the last op in the region. + // Only pick up `ForOp` if it is the last op in the region. Operation *lastNode = newWarpOpYield->getPrevNode(); auto forOp = dyn_cast_or_null(lastNode); if (!forOp) return failure(); - // Collect Values that come from the warp op but are outside the forOp. - // Those Value needs to be returned by the new warp op. + // Collect Values that come from the `WarpOp` but are outside the `ForOp`. + // Those Values need to be returned by the new warp op. llvm::SmallSetVector escapingValues; SmallVector escapingValueInputTypes; SmallVector escapingValuedistTypes; @@ -1779,16 +1779,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern { if (llvm::is_contained(escapingValuedistTypes, Type{})) return failure(); - // Warp op can yield two types of values: - // 1. Values that are not results of the forOp: - // These values must also be yielded by the new warp op. Also, we need to - // record the index mapping for these values to replace them later. - // 2. Values that are results of the forOp: - // In this case, we record the index mapping between the warp op result - // index and matching forOp result index. + // `WarpOp` can yield two types of values: + // 1. Values that are not results of the `ForOp`: + // These values must also be yielded by the new `WarpOp`. Also, we need + // to record the index mapping for these values to replace them later. + // 2. Values that are results of the `ForOp`: + // In this case, we record the index mapping between the `WarpOp` result + // index and matching `ForOp` result index. SmallVector nonForYieldedValues; SmallVector nonForResultIndices; - DenseMap forResultMapping; + llvm::SmallDenseMap forResultMapping; for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) { // Yielded value is not a result of the forOp. if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { @@ -1801,10 +1801,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern { forResult.getResultNumber(); } - // Newly created warp op will yield values in following order: - // 1. All init args of the forOp. + // Newly created `WarpOp` will yield values in following order: + // 1. All init args of the `ForOp`. // 2. All escaping values. - // 3. All non-for yielded values. + // 3. All non-`ForOp` yielded values. SmallVector newWarpOpYieldValues; SmallVector newWarpOpDistTypes; for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { @@ -1823,50 +1823,50 @@ struct WarpOpScfForOp : public WarpDistributionPattern { newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), escapingValuedistTypes.begin(), escapingValuedistTypes.end()); - // Next, we insert all non-for yielded values and their distributed types. - // We also create a mapping between the non-for yielded value index and the - // corresponding new warp op yield value index (needed to update users - // later). - DenseMap warpResultMapping; + // Next, we insert all non-`ForOp` yielded values and their distributed + // types. We also create a mapping between the non-`ForOp` yielded value + // index and the corresponding new `WarpOp` yield value index (needed to + // update users later). + llvm::SmallDenseMap warpResultMapping; for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) { warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); newWarpOpDistTypes.push_back( warpOp.getResult(nonForResultIndices[i]).getType()); } - // Create the new warp op with the updated yield values and types. + // Create the new `WarpOp` with the updated yield values and types. WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); newWarpOpYield = cast( newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - // Next, we create a new for op with the init args yielded by the new - // warp op. - unsigned escapingValuesStartIdx = - forOp.getInitArgs().size(); // ForOp init args are positioned before - // escaping values in the new warp op. + // Next, we create a new `ForOp` with the init args yielded by the new + // `WarpOp`. + const unsigned escapingValuesStartIdx = + forOp.getInitArgs().size(); // `ForOp` init args are positioned before + // escaping values in the new `WarpOp`. SmallVector newForOpOperands; for (size_t i = 0; i < escapingValuesStartIdx; ++i) newForOpOperands.push_back(newWarpOp.getResult(i)); - // Create a new for op outside the new warp op region. + // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newForOpOperands); - // Next, we insert a new warp op (called inner warp op) inside the - // newly created for op. This warp op will contain all ops that were - // contained within the original for op body. + // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the + // newly created `ForOp`. This `WarpOp` will contain all ops that were + // contained within the original `ForOp` body. rewriter.setInsertionPointToStart(newForOp.getBody()); SmallVector innerWarpInput(newForOp.getRegionIterArgs().begin(), newForOp.getRegionIterArgs().end()); SmallVector innerWarpInputType(forOp.getResultTypes().begin(), forOp.getResultTypes().end()); - // Escaping values are forwarded to the inner warp op as its (additional) + // Escaping values are forwarded to the inner `WarpOp` as its (additional) // arguments. We keep track of the mapping between these values and their - // argument index in the inner warp op (to replcace uses later). + // argument index in the inner `WarpOp` (to replace users later). llvm::SmallDenseMap argIndexMapping; for (size_t i = escapingValuesStartIdx; i < escapingValuesStartIdx + escapingValues.size(); ++i) { @@ -1876,12 +1876,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern { innerWarpInputType.push_back( escapingValueInputTypes[i - escapingValuesStartIdx]); } - // Create the inner warp op with the new input values and types. + // Create the inner `WarpOp` with the new input values and types. auto innerWarp = rewriter.create( newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType); - // Inline the for op body into the inner warp op body. + // Inline the `ForOp` body into the inner `WarpOp` body. SmallVector argMapping; argMapping.push_back(newForOp.getInductionVar()); for (Value args : innerWarp.getBody()->getArguments()) @@ -1895,32 +1895,32 @@ struct WarpOpScfForOp : public WarpDistributionPattern { rewriter.eraseOp(forOp.getBody()->getTerminator()); rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); - // Insert a gpu yieldOp at the end of the inner warp op body that yields - // original forOp results. + // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields + // original `ForOp` results. rewriter.setInsertionPointToEnd(innerWarp.getBody()); rewriter.create(innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); - // Insert a scf.yield op at the end of the new for op body that yields - // the inner warp op results. + // Insert a scf.yield op at the end of the new `ForOp` body that yields + // the inner `WarpOp` results. if (!innerWarp.getResults().empty()) rewriter.create(forOp.getLoc(), innerWarp.getResults()); - // Update the users of original warp op results that were coming from the - // original forOp to the corresponding new forOp result. + // Update the users of original `WarpOp` results that were coming from the + // original `ForOp` to the corresponding new `ForOp` result. for (auto [origIdx, newIdx] : forResultMapping) rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), newForOp.getResult(newIdx), newForOp); - // Similarly, update any users of the warp op results that were not - // results of the forOp. + // Similarly, update any users of the `WarpOp` results that were not + // results of the `ForOp`. for (auto [origIdx, newIdx] : warpResultMapping) rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), newWarpOp.getResult(newIdx)); - // Remove the original warp op and for op, they should not have any uses + // Remove the original `WarpOp` and `ForOp`, they should not have any uses // at this point. rewriter.eraseOp(forOp); rewriter.eraseOp(warpOp); // Update any users of escaping values that were forwarded to the - // inner warp op. These values are now arguments of the inner warp op. + // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. newForOp.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); @@ -1930,7 +1930,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { } }); - // Finally, hoist out any now uniform code from the inner warp op. + // Finally, hoist out any now uniform code from the inner `WarpOp`. mlir::vector::moveScalarUniformCode(innerWarp); return success(); } From 683fad8092e6c40a1bc95bf2f249653236f72960 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Thu, 10 Jul 2025 22:18:19 +0000 Subject: [PATCH 08/13] address comments --- .../lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7d3d6b98666a1..f4928ee6c4221 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1760,7 +1760,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Those Values need to be returned by the new warp op. llvm::SmallSetVector escapingValues; SmallVector escapingValueInputTypes; - SmallVector escapingValuedistTypes; + SmallVector escapingValueDistTypes; mlir::visitUsedValuesDefinedAbove( forOp.getBodyRegion(), [&](OpOperand *operand) { Operation *parent = operand->get().getParentRegion()->getParentOp(); @@ -1773,11 +1773,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } escapingValueInputTypes.push_back(operand->get().getType()); - escapingValuedistTypes.push_back(distType); + escapingValueDistTypes.push_back(distType); } }); - if (llvm::is_contained(escapingValuedistTypes, Type{})) + if (llvm::is_contained(escapingValueDistTypes, Type{})) return failure(); // `WarpOp` can yield two types of values: // 1. Values that are not results of the `ForOp`: @@ -1821,8 +1821,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { newWarpOpYieldValues.insert(newWarpOpYieldValues.end(), escapingValues.begin(), escapingValues.end()); newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), - escapingValuedistTypes.begin(), - escapingValuedistTypes.end()); + escapingValueDistTypes.begin(), + escapingValueDistTypes.end()); // Next, we insert all non-`ForOp` yielded values and their distributed // types. We also create a mapping between the non-`ForOp` yielded value // index and the corresponding new `WarpOp` yield value index (needed to From 2c2703eb9511ff8b2df44de23ffa0efb15308772 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 11 Jul 2025 16:26:37 +0000 Subject: [PATCH 09/13] address comments --- .../Dialect/Vector/Transforms/VectorDistribute.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index f4928ee6c4221..af7c34f354668 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1749,10 +1749,10 @@ struct WarpOpScfForOp : public WarpDistributionPattern { : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto newWarpOpYield = cast( + auto warpOpYield = cast( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); // Only pick up `ForOp` if it is the last op in the region. - Operation *lastNode = newWarpOpYield->getPrevNode(); + Operation *lastNode = warpOpYield->getPrevNode(); auto forOp = dyn_cast_or_null(lastNode); if (!forOp) return failure(); @@ -1789,7 +1789,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { SmallVector nonForYieldedValues; SmallVector nonForResultIndices; llvm::SmallDenseMap forResultMapping; - for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) { + for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) { // Yielded value is not a result of the forOp. if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { nonForYieldedValues.push_back(yieldOperand.get()); @@ -1827,9 +1827,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // types. We also create a mapping between the non-`ForOp` yielded value // index and the corresponding new `WarpOp` yield value index (needed to // update users later). - llvm::SmallDenseMap warpResultMapping; + llvm::SmallDenseMap nonForResultMapping; for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) { - warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size(); + nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); newWarpOpDistTypes.push_back( warpOp.getResult(nonForResultIndices[i]).getType()); @@ -1837,8 +1837,6 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Create the new `WarpOp` with the updated yield values and types. WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); - newWarpOpYield = cast( - newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); // Next, we create a new `ForOp` with the init args yielded by the new // `WarpOp`. @@ -1912,7 +1910,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { newForOp.getResult(newIdx), newForOp); // Similarly, update any users of the `WarpOp` results that were not // results of the `ForOp`. - for (auto [origIdx, newIdx] : warpResultMapping) + for (auto [origIdx, newIdx] : nonForResultMapping) rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), newWarpOp.getResult(newIdx)); // Remove the original `WarpOp` and `ForOp`, they should not have any uses From 6297e47a62c8215fb25cd06f165241c10763d23c Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 11 Jul 2025 19:47:05 +0000 Subject: [PATCH 10/13] address comments --- .../Dialect/Vector/Transforms/VectorDistribute.cpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 15c14bef37e76..e62031412eab6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -17,12 +17,8 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -1787,11 +1783,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // index and the corresponding new `WarpOp` yield value index (needed to // update users later). llvm::SmallDenseMap nonForResultMapping; - for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) { - nonForResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size(); + for (auto [i, v] : + llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { + nonForResultMapping[i] = newWarpOpYieldValues.size(); newWarpOpYieldValues.push_back(v); - newWarpOpDistTypes.push_back( - warpOp.getResult(nonForResultIndices[i]).getType()); + newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); } // Create the new `WarpOp` with the updated yield values and types. WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( From 3690c61ddefb8b2b663814b44bf1dbda14605d8f Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 11 Jul 2025 19:50:22 +0000 Subject: [PATCH 11/13] address comments --- mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index ef257307de569..5319496edc5af 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -34,7 +34,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" namespace mlir { namespace xegpu { From 20c2cf67662c3b3fdecf95a0e280809f98d8db50 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 11 Jul 2025 21:54:15 +0000 Subject: [PATCH 12/13] bug fix --- .../Vector/Transforms/VectorDistribute.cpp | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e62031412eab6..436029c31e7f8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1741,9 +1741,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // 2. Values that are results of the `ForOp`: // In this case, we record the index mapping between the `WarpOp` result // index and matching `ForOp` result index. + // Additionally, we keep track of the distributed types for all `ForOp` + // vector results. SmallVector nonForYieldedValues; SmallVector nonForResultIndices; llvm::SmallDenseMap forResultMapping; + llvm::SmallDenseMap forResultDistTypes; for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) { // Yielded value is not a result of the forOp. if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { @@ -1752,8 +1755,15 @@ struct WarpOpScfForOp : public WarpDistributionPattern { continue; } OpResult forResult = cast(yieldOperand.get()); - forResultMapping[yieldOperand.getOperandNumber()] = - forResult.getResultNumber(); + unsigned int forResultNumber = forResult.getResultNumber(); + forResultMapping[yieldOperand.getOperandNumber()] = forResultNumber; + // If this `ForOp` result is vector type and it is yielded by the + // `WarpOp`, we keep track the distributed type for this result. + if (!isa(forResult.getType())) + continue; + VectorType distType = cast( + warpOp.getResult(yieldOperand.getOperandNumber()).getType()); + forResultDistTypes[forResultNumber] = distType; } // Newly created `WarpOp` will yield values in following order: @@ -1767,8 +1777,13 @@ struct WarpOpScfForOp : public WarpDistributionPattern { // Compute the distributed type for this init arg. Type distType = initArg.getType(); if (auto vecType = dyn_cast(distType)) { + // If the `ForOp` result corresponds to this init arg is already yielded + // we can get the distributed type from `forResultDistTypes` map. + // Otherwise, we compute it using distributionMapFn. AffineMap map = distributionMapFn(initArg); - distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + distType = forResultDistTypes.count(i) + ? forResultDistTypes[i] + : getDistributedType(vecType, map, warpOp.getWarpSize()); } newWarpOpDistTypes.push_back(distType); } From d72b09606f32e8ac3701f12bbbf2bfc5ac08aac7 Mon Sep 17 00:00:00 2001 From: Charitha Saumya Date: Fri, 11 Jul 2025 22:42:34 +0000 Subject: [PATCH 13/13] bug fix --- .../Vector/vector-warp-distribute.mlir | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index c6342f07fc314..ae8fce786ee57 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -531,6 +531,35 @@ func.func @warp_scf_for_swap_no_yield(%arg0: index) { return } +// ----- +// scf.for result is not distributed in this case. +// CHECK-PROP-LABEL: func @warp_scf_for_broadcasted_result( +// CHECK-PROP: %[[W0:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) { +// CHECK-PROP: %[[INI:.*]] = "some_def"() : () -> vector<1xf32> +// CHECK-PROP: gpu.yield %[[INI]] : vector<1xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[F:.*]] = scf.for {{.*}} iter_args(%[[ARG2:.*]] = %[[W0]]) -> (vector<1xf32>) { +// CHECK-PROP: %[[W1:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%[[ARG2]] : vector<1xf32>) -> (vector<1xf32>) { +// CHECK-PROP: ^bb0(%{{.*}}: vector<1xf32>): +// CHECK-PROP: %[[T0:.*]] = "some_op"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> +// CHECK-PROP: gpu.yield %[[T0]] : vector<1xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[W1]] : vector<1xf32> +func.func @warp_scf_for_broadcasted_result(%arg0: index) -> vector<1xf32> { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %2 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<1xf32>) { + %ini = "some_def"() : () -> (vector<1xf32>) + %0 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<1xf32>) { + %1 = "some_op"(%arg4) : (vector<1xf32>) -> (vector<1xf32>) + scf.yield %1 : vector<1xf32> + } + gpu.yield %0 : vector<1xf32> + } + return %2 : vector<1xf32> +} + // ----- #map = affine_map<()[s0] -> (s0 * 4)>