diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index c8566b1ff83ef..e62031412eab6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1704,19 +1704,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern { : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( + auto warpOpYield = cast<gpu::YieldOp>( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - // Only pick up forOp if it is the last op in the region. - Operation *lastNode = yield->getPrevNode(); + // Only pick up `ForOp` if it is the last op in the region. + Operation *lastNode = warpOpYield->getPrevNode(); auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); if (!forOp) return failure(); - // Collect Values that come from the warp op but are outside the forOp. - // Those Value needs to be returned by the original warpOp and passed to - // the new op. + // Collect Values that come from the `WarpOp` but are outside the `ForOp`. + // Those Values need to be returned by the new warp op. 
llvm::SmallSetVector<Value, 32> escapingValues; - SmallVector<Type> inputTypes; - SmallVector<Type> distTypes; + SmallVector<Type> escapingValueInputTypes; + SmallVector<Type> escapingValueDistTypes; mlir::visitUsedValuesDefinedAbove( forOp.getBodyRegion(), [&](OpOperand *operand) { Operation *parent = operand->get().getParentRegion()->getParentOp(); @@ -1728,81 +1727,153 @@ struct WarpOpScfForOp : public WarpDistributionPattern { AffineMap map = distributionMapFn(operand->get()); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } - inputTypes.push_back(operand->get().getType()); - distTypes.push_back(distType); + escapingValueInputTypes.push_back(operand->get().getType()); + escapingValueDistTypes.push_back(distType); } }); - if (llvm::is_contained(distTypes, Type{})) + if (llvm::is_contained(escapingValueDistTypes, Type{})) return failure(); - - SmallVector<size_t> newRetIndices; - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, escapingValues.getArrayRef(), distTypes, - newRetIndices); - yield = cast<gpu::YieldOp>( - newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - - SmallVector<Value> newOperands; - SmallVector<unsigned> resultIdx; - // Collect all the outputs coming from the forOp. - for (OpOperand &yieldOperand : yield->getOpOperands()) { - if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) + // `WarpOp` can yield two types of values: + // 1. Values that are not results of the `ForOp`: + //    These values must also be yielded by the new `WarpOp`. Also, we need + //    to record the index mapping for these values to replace them later. + // 2. Values that are results of the `ForOp`: + //    In this case, we record the index mapping between the `WarpOp` result + //    index and matching `ForOp` result index. + SmallVector<Value> nonForYieldedValues; + SmallVector<unsigned> nonForResultIndices; + llvm::SmallDenseMap<unsigned, unsigned> forResultMapping; + for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) { + // Yielded value is not a result of the forOp. 
+ if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { + nonForYieldedValues.push_back(yieldOperand.get()); + nonForResultIndices.push_back(yieldOperand.getOperandNumber()); continue; - auto forResult = cast<OpResult>(yieldOperand.get()); - newOperands.push_back( - newWarpOp.getResult(yieldOperand.getOperandNumber())); - yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); - resultIdx.push_back(yieldOperand.getOperandNumber()); + } + OpResult forResult = cast<OpResult>(yieldOperand.get()); + forResultMapping[yieldOperand.getOperandNumber()] = + forResult.getResultNumber(); } + // Newly created `WarpOp` will yield values in following order: + // 1. All init args of the `ForOp`. + // 2. All escaping values. + // 3. All non-`ForOp` yielded values. + SmallVector<Value> newWarpOpYieldValues; + SmallVector<Type> newWarpOpDistTypes; + for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { + newWarpOpYieldValues.push_back(initArg); + // Compute the distributed type for this init arg. + Type distType = initArg.getType(); + if (auto vecType = dyn_cast<VectorType>(distType)) { + AffineMap map = distributionMapFn(initArg); + distType = getDistributedType(vecType, map, warpOp.getWarpSize()); + } + newWarpOpDistTypes.push_back(distType); + } + // Insert escaping values and their distributed types. + newWarpOpYieldValues.insert(newWarpOpYieldValues.end(), + escapingValues.begin(), escapingValues.end()); + newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), + escapingValueDistTypes.begin(), + escapingValueDistTypes.end()); + // Next, we insert all non-`ForOp` yielded values and their distributed + // types. We also create a mapping between the non-`ForOp` yielded value + // index and the corresponding new `WarpOp` yield value index (needed to + // update users later). 
+ llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping; + for (auto [i, v] : + llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { + nonForResultMapping[i] = newWarpOpYieldValues.size(); + newWarpOpYieldValues.push_back(v); + newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); + } + // Create the new `WarpOp` with the updated yield values and types. + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( + rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); + + // Next, we create a new `ForOp` with the init args yielded by the new + // `WarpOp`. + const unsigned escapingValuesStartIdx = + forOp.getInitArgs().size(); // `ForOp` init args are positioned before + // escaping values in the new `WarpOp`. + SmallVector<Value> newForOpOperands; + for (size_t i = 0; i < escapingValuesStartIdx; ++i) + newForOpOperands.push_back(newWarpOp.getResult(i)); + + // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); - - // Create a new for op outside the region with a WarpExecuteOnLane0Op - // region inside. auto newForOp = rewriter.create<scf::ForOp>( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newOperands); + forOp.getStep(), newForOpOperands); + // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the + // newly created `ForOp`. This `WarpOp` will contain all ops that were + // contained within the original `ForOp` body. 
rewriter.setInsertionPointToStart(newForOp.getBody()); - SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(), - newForOp.getRegionIterArgs().end()); - SmallVector<Type> warpInputType(forOp.getResultTypes().begin(), - forOp.getResultTypes().end()); + SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(), + newForOp.getRegionIterArgs().end()); + SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(), + forOp.getResultTypes().end()); + // Escaping values are forwarded to the inner `WarpOp` as its (additional) + // arguments. We keep track of the mapping between these values and their + // argument index in the inner `WarpOp` (to replace users later). llvm::SmallDenseMap<Value, int64_t> argIndexMapping; - for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { - warpInput.push_back(newWarpOp.getResult(retIdx)); - argIndexMapping[escapingValues[i]] = warpInputType.size(); - warpInputType.push_back(inputTypes[i]); + for (size_t i = escapingValuesStartIdx; + i < escapingValuesStartIdx + escapingValues.size(); ++i) { + innerWarpInput.push_back(newWarpOp.getResult(i)); + argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = + innerWarpInputType.size(); + innerWarpInputType.push_back( + escapingValueInputTypes[i - escapingValuesStartIdx]); } + // Create the inner `WarpOp` with the new input values and types. auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>( newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), - newWarpOp.getWarpSize(), warpInput, warpInputType); + newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType); + // Inline the `ForOp` body into the inner `WarpOp` body. 
SmallVector<Value> argMapping; argMapping.push_back(newForOp.getInductionVar()); - for (Value args : innerWarp.getBody()->getArguments()) { + for (Value args : innerWarp.getBody()->getArguments()) argMapping.push_back(args); - } + argMapping.resize(forOp.getBody()->getNumArguments()); SmallVector<Value> yieldOperands; for (Value operand : forOp.getBody()->getTerminator()->getOperands()) yieldOperands.push_back(operand); + rewriter.eraseOp(forOp.getBody()->getTerminator()); rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); + + // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields + // original `ForOp` results. rewriter.setInsertionPointToEnd(innerWarp.getBody()); rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); + // Insert a scf.yield op at the end of the new `ForOp` body that yields + // the inner `WarpOp` results. if (!innerWarp.getResults().empty()) rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults()); + + // Update the users of original `WarpOp` results that were coming from the + // original `ForOp` to the corresponding new `ForOp` result. + for (auto [origIdx, newIdx] : forResultMapping) + rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), + newForOp.getResult(newIdx), newForOp); + // Similarly, update any users of the `WarpOp` results that were not + // results of the `ForOp`. + for (auto [origIdx, newIdx] : nonForResultMapping) + rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), + newWarpOp.getResult(newIdx)); + // Remove the original `WarpOp` and `ForOp`, they should not have any uses + // at this point. rewriter.eraseOp(forOp); - // Replace the warpOp result coming from the original ForOp. 
- for (const auto &res : llvm::enumerate(resultIdx)) { - rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), - newForOp.getResult(res.index())); - newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); - } + rewriter.eraseOp(warpOp); + // Update any users of escaping values that were forwarded to the + // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. newForOp.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); @@ -1812,7 +1883,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { } }); - // Finally, hoist out any now uniform code from the inner warp op. + // Finally, hoist out any now uniform code from the inner `WarpOp`. mlir::vector::moveScalarUniformCode(innerWarp); return success(); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index c072557c2bd22..5319496edc5af 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -876,15 +876,32 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Step 3: Apply subgroup to workitem distribution patterns. RewritePatternSet patterns(&getContext()); xegpu::populateXeGPUSubgroupDistributePatterns(patterns); - // TODO: distributionFn and shuffleFn are not used at this point. + // distributionFn is used by vector distribution patterns to determine the + // distributed vector type for a given vector value. In XeGPU subgroup + // distribution context, we compute this based on lane layout. auto distributionFn = [](Value val) { VectorType vecType = dyn_cast<VectorType>(val.getType()); int64_t vecRank = vecType ? 
vecType.getRank() : 0; - OpBuilder builder(val.getContext()); if (vecRank == 0) return AffineMap::get(val.getContext()); - return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); + // Get the layout of the vector type. + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); + // If no layout is specified, assume the innermost dimension is distributed + // for now. + if (!layout) + return AffineMap::getMultiDimMapWithTargets( + vecRank, {static_cast<unsigned>(vecRank - 1)}, val.getContext()); + SmallVector<unsigned> distributedDims; + // Get the distributed dimensions based on the layout. + ArrayRef<int32_t> laneLayout = layout.getLaneLayout().asArrayRef(); + for (unsigned i = 0; i < laneLayout.size(); ++i) { + if (laneLayout[i] > 1) + distributedDims.push_back(i); + } + return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, + val.getContext()); }; + // TODO: shuffleFn is not used. auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, int64_t warpSz) { return Value(); }; vector::populatePropagateWarpVectorDistributionPatterns( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 9fa9d56e4a324..c6342f07fc314 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2 return } +// ----- +// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result( +// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) { +// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, 
vector<4xf32>) { +// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) { +// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32> +// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32> +// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32> +// CHECK-PROP: } +// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32> +// CHECK-PROP: } +// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> () +func.func @warp_scf_for_unused_for_result(%arg0: index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) { + %ini = "some_def"() : () -> (vector<128xf32>) + %ini1 = "some_def"() : () -> (vector<128xf32>) + %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) { + %add = arith.addi %arg3, %c1 : index + %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>) + %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc, %1 : vector<128xf32>, vector<128xf32> + } + gpu.yield %3#0 : vector<128xf32> + } + "some_use"(%0) : (vector<4xf32>) -> () + return +} + +// ----- +// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results( +// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { +// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32> +// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32> +// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32> +// 
CHECK-PROP-NEXT: } +// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { +// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : +// CHECK-PROP-SAME: vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { +// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>): +// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32> +// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32> +// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32> +// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32> +// CHECK-PROP-NEXT: } +// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32> +// CHECK-PROP-NEXT: } +// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> () +// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> () +// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> () +func.func @warp_scf_for_swapped_for_results(%arg0: index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) { + %ini1 = "some_def"() : () -> (vector<256xf32>) + %ini2 = "some_def"() : () -> (vector<128xf32>) + %ini3 = "some_def"() : () -> (vector<128xf32>) + %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) { + %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>) + %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>) + %acc3 = 
"some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>) + scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32> + } + gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> + } + "some_use_1"(%0#0) : (vector<4xf32>) -> () + "some_use_2"(%0#1) : (vector<4xf32>) -> () + "some_use_3"(%0#2) : (vector<8xf32>) -> () + return +} + // ----- // CHECK-PROP-LABEL: func @vector_reduction(