diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index e62031412eab6..c8566b1ff83ef 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -1704,18 +1704,19 @@ struct WarpOpScfForOp : public WarpDistributionPattern { : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto warpOpYield = cast<gpu::YieldOp>( + auto yield = cast<gpu::YieldOp>( warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - // Only pick up `ForOp` if it is the last op in the region. - Operation *lastNode = warpOpYield->getPrevNode(); + // Only pick up forOp if it is the last op in the region. + Operation *lastNode = yield->getPrevNode(); auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); if (!forOp) return failure(); - // Collect Values that come from the `WarpOp` but are outside the `ForOp`. - // Those Values need to be returned by the new warp op. + // Collect Values that come from the warp op but are outside the forOp. + // Those Values need to be returned by the original warpOp and passed to + // the new op. 
llvm::SmallSetVector<Value, 32> escapingValues; - SmallVector<Type> escapingValueInputTypes; - SmallVector<Type> escapingValueDistTypes; + SmallVector<Type> inputTypes; + SmallVector<Type> distTypes; mlir::visitUsedValuesDefinedAbove( forOp.getBodyRegion(), [&](OpOperand *operand) { Operation *parent = operand->get().getParentRegion()->getParentOp(); @@ -1727,153 +1728,81 @@ struct WarpOpScfForOp : public WarpDistributionPattern { AffineMap map = distributionMapFn(operand->get()); distType = getDistributedType(vecType, map, warpOp.getWarpSize()); } - escapingValueInputTypes.push_back(operand->get().getType()); - escapingValueDistTypes.push_back(distType); + inputTypes.push_back(operand->get().getType()); + distTypes.push_back(distType); } }); - if (llvm::is_contained(escapingValueDistTypes, Type{})) + if (llvm::is_contained(distTypes, Type{})) return failure(); - // `WarpOp` can yield two types of values: - // 1. Values that are not results of the `ForOp`: - // These values must also be yielded by the new `WarpOp`. Also, we need - // to record the index mapping for these values to replace them later. - // 2. Values that are results of the `ForOp`: - // In this case, we record the index mapping between the `WarpOp` result - // index and matching `ForOp` result index. - SmallVector<Value> nonForYieldedValues; - SmallVector<unsigned> nonForResultIndices; - llvm::SmallDenseMap<unsigned, unsigned> forResultMapping; - for (OpOperand &yieldOperand : warpOpYield->getOpOperands()) { - // Yielded value is not a result of the forOp. 
- if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) { - nonForYieldedValues.push_back(yieldOperand.get()); - nonForResultIndices.push_back(yieldOperand.getOperandNumber()); + + SmallVector newRetIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, escapingValues.getArrayRef(), distTypes, + newRetIndices); + yield = cast( + newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + + SmallVector newOperands; + SmallVector resultIdx; + // Collect all the outputs coming from the forOp. + for (OpOperand &yieldOperand : yield->getOpOperands()) { + if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) continue; - } - OpResult forResult = cast(yieldOperand.get()); - forResultMapping[yieldOperand.getOperandNumber()] = - forResult.getResultNumber(); + auto forResult = cast(yieldOperand.get()); + newOperands.push_back( + newWarpOp.getResult(yieldOperand.getOperandNumber())); + yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); + resultIdx.push_back(yieldOperand.getOperandNumber()); } - // Newly created `WarpOp` will yield values in following order: - // 1. All init args of the `ForOp`. - // 2. All escaping values. - // 3. All non-`ForOp` yielded values. - SmallVector newWarpOpYieldValues; - SmallVector newWarpOpDistTypes; - for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) { - newWarpOpYieldValues.push_back(initArg); - // Compute the distributed type for this init arg. - Type distType = initArg.getType(); - if (auto vecType = dyn_cast(distType)) { - AffineMap map = distributionMapFn(initArg); - distType = getDistributedType(vecType, map, warpOp.getWarpSize()); - } - newWarpOpDistTypes.push_back(distType); - } - // Insert escaping values and their distributed types. 
- newWarpOpYieldValues.insert(newWarpOpYieldValues.end(), - escapingValues.begin(), escapingValues.end()); - newWarpOpDistTypes.insert(newWarpOpDistTypes.end(), - escapingValueDistTypes.begin(), - escapingValueDistTypes.end()); - // Next, we insert all non-`ForOp` yielded values and their distributed - // types. We also create a mapping between the non-`ForOp` yielded value - // index and the corresponding new `WarpOp` yield value index (needed to - // update users later). - llvm::SmallDenseMap nonForResultMapping; - for (auto [i, v] : - llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) { - nonForResultMapping[i] = newWarpOpYieldValues.size(); - newWarpOpYieldValues.push_back(v); - newWarpOpDistTypes.push_back(warpOp.getResult(i).getType()); - } - // Create the new `WarpOp` with the updated yield values and types. - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes); - - // Next, we create a new `ForOp` with the init args yielded by the new - // `WarpOp`. - const unsigned escapingValuesStartIdx = - forOp.getInitArgs().size(); // `ForOp` init args are positioned before - // escaping values in the new `WarpOp`. - SmallVector newForOpOperands; - for (size_t i = 0; i < escapingValuesStartIdx; ++i) - newForOpOperands.push_back(newWarpOp.getResult(i)); - - // Create a new `ForOp` outside the new `WarpOp` region. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newWarpOp); + + // Create a new for op outside the region with a WarpExecuteOnLane0Op + // region inside. auto newForOp = rewriter.create( forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands); - // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the - // newly created `ForOp`. This `WarpOp` will contain all ops that were - // contained within the original `ForOp` body. 
+ forOp.getStep(), newOperands); rewriter.setInsertionPointToStart(newForOp.getBody()); - SmallVector innerWarpInput(newForOp.getRegionIterArgs().begin(), - newForOp.getRegionIterArgs().end()); - SmallVector innerWarpInputType(forOp.getResultTypes().begin(), - forOp.getResultTypes().end()); - // Escaping values are forwarded to the inner `WarpOp` as its (additional) - // arguments. We keep track of the mapping between these values and their - // argument index in the inner `WarpOp` (to replace users later). + SmallVector warpInput(newForOp.getRegionIterArgs().begin(), + newForOp.getRegionIterArgs().end()); + SmallVector warpInputType(forOp.getResultTypes().begin(), + forOp.getResultTypes().end()); llvm::SmallDenseMap argIndexMapping; - for (size_t i = escapingValuesStartIdx; - i < escapingValuesStartIdx + escapingValues.size(); ++i) { - innerWarpInput.push_back(newWarpOp.getResult(i)); - argIndexMapping[escapingValues[i - escapingValuesStartIdx]] = - innerWarpInputType.size(); - innerWarpInputType.push_back( - escapingValueInputTypes[i - escapingValuesStartIdx]); + for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { + warpInput.push_back(newWarpOp.getResult(retIdx)); + argIndexMapping[escapingValues[i]] = warpInputType.size(); + warpInputType.push_back(inputTypes[i]); } - // Create the inner `WarpOp` with the new input values and types. auto innerWarp = rewriter.create( newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), - newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType); + newWarpOp.getWarpSize(), warpInput, warpInputType); - // Inline the `ForOp` body into the inner `WarpOp` body. 
SmallVector argMapping; argMapping.push_back(newForOp.getInductionVar()); - for (Value args : innerWarp.getBody()->getArguments()) + for (Value args : innerWarp.getBody()->getArguments()) { argMapping.push_back(args); - + } argMapping.resize(forOp.getBody()->getNumArguments()); SmallVector yieldOperands; for (Value operand : forOp.getBody()->getTerminator()->getOperands()) yieldOperands.push_back(operand); - rewriter.eraseOp(forOp.getBody()->getTerminator()); rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); - - // Insert a gpu `YieldOp` at the end of the inner `WarpOp` body that yields - // original `ForOp` results. rewriter.setInsertionPointToEnd(innerWarp.getBody()); rewriter.create(innerWarp.getLoc(), yieldOperands); rewriter.setInsertionPointAfter(innerWarp); - // Insert a scf.yield op at the end of the new `ForOp` body that yields - // the inner `WarpOp` results. if (!innerWarp.getResults().empty()) rewriter.create(forOp.getLoc(), innerWarp.getResults()); - - // Update the users of original `WarpOp` results that were coming from the - // original `ForOp` to the corresponding new `ForOp` result. - for (auto [origIdx, newIdx] : forResultMapping) - rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx), - newForOp.getResult(newIdx), newForOp); - // Similarly, update any users of the `WarpOp` results that were not - // results of the `ForOp`. - for (auto [origIdx, newIdx] : nonForResultMapping) - rewriter.replaceAllUsesWith(warpOp.getResult(origIdx), - newWarpOp.getResult(newIdx)); - // Remove the original `WarpOp` and `ForOp`, they should not have any uses - // at this point. rewriter.eraseOp(forOp); - rewriter.eraseOp(warpOp); - // Update any users of escaping values that were forwarded to the - // inner `WarpOp`. These values are now arguments of the inner `WarpOp`. + // Replace the warpOp result coming from the original ForOp. 
+ for (const auto &res : llvm::enumerate(resultIdx)) { + rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), + newForOp.getResult(res.index())); + newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); + } newForOp.walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { auto it = argIndexMapping.find(operand.get()); @@ -1883,7 +1812,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { } }); - // Finally, hoist out any now uniform code from the inner `WarpOp`. + // Finally, hoist out any now uniform code from the inner warp op. mlir::vector::moveScalarUniformCode(innerWarp); return success(); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 5319496edc5af..c072557c2bd22 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -876,32 +876,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() { // Step 3: Apply subgroup to workitem distribution patterns. RewritePatternSet patterns(&getContext()); xegpu::populateXeGPUSubgroupDistributePatterns(patterns); - // distributionFn is used by vector distribution patterns to determine the - // distributed vector type for a given vector value. In XeGPU subgroup - // distribution context, we compute this based on lane layout. + // TODO: distributionFn and shuffleFn are not used at this point. auto distributionFn = [](Value val) { VectorType vecType = dyn_cast(val.getType()); int64_t vecRank = vecType ? vecType.getRank() : 0; + OpBuilder builder(val.getContext()); if (vecRank == 0) return AffineMap::get(val.getContext()); - // Get the layout of the vector type. - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); - // If no layout is specified, assume the inner most dimension is distributed - // for now. 
- if (!layout) - return AffineMap::getMultiDimMapWithTargets( - vecRank, {static_cast(vecRank - 1)}, val.getContext()); - SmallVector distributedDims; - // Get the distributed dimensions based on the layout. - ArrayRef laneLayout = layout.getLaneLayout().asArrayRef(); - for (unsigned i = 0; i < laneLayout.size(); ++i) { - if (laneLayout[i] > 1) - distributedDims.push_back(i); - } - return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, - val.getContext()); + return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); }; - // TODO: shuffleFn is not used. auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, int64_t warpSz) { return Value(); }; vector::populatePropagateWarpVectorDistributionPatterns( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index c6342f07fc314..9fa9d56e4a324 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -584,85 +584,6 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref, %arg2 return } -// ----- -// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result( -// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32> -// CHECK-PROP: } -// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) { -// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32> -// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : 
(index, vector<128xf32>, vector<128xf32>) -> vector<128xf32> -// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32> -// CHECK-PROP: } -// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32> -// CHECK-PROP: } -// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> () -func.func @warp_scf_for_unused_for_result(%arg0: index) { - %c128 = arith.constant 128 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) { - %ini = "some_def"() : () -> (vector<128xf32>) - %ini1 = "some_def"() : () -> (vector<128xf32>) - %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) { - %add = arith.addi %arg3, %c1 : index - %1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>) - %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>) - scf.yield %acc, %1 : vector<128xf32>, vector<128xf32> - } - gpu.yield %3#0 : vector<128xf32> - } - "some_use"(%0) : (vector<4xf32>) -> () - return -} - -// ----- -// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results( -// CHECK-PROP: %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { -// CHECK-PROP-NEXT: %[[INI0:.*]] = "some_def"() : () -> vector<256xf32> -// CHECK-PROP-NEXT: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP-NEXT: %[[INI2:.*]] = "some_def"() : () -> vector<128xf32> -// CHECK-PROP-NEXT: gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32> -// CHECK-PROP-NEXT: } -// CHECK-PROP-NEXT: %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { -// CHECK-PROP-NEXT: %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : -// CHECK-PROP-SAME: 
vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) { -// CHECK-PROP-NEXT: ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>): -// CHECK-PROP-NEXT: %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32> -// CHECK-PROP-NEXT: %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32> -// CHECK-PROP-NEXT: %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32> -// CHECK-PROP-NEXT: gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32> -// CHECK-PROP-NEXT: } -// CHECK-PROP-NEXT: scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32> -// CHECK-PROP-NEXT: } -// CHECK-PROP-NEXT: "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> () -// CHECK-PROP-NEXT: "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> () -// CHECK-PROP-NEXT: "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> () -func.func @warp_scf_for_swapped_for_results(%arg0: index) { - %c128 = arith.constant 128 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) { - %ini1 = "some_def"() : () -> (vector<256xf32>) - %ini2 = "some_def"() : () -> (vector<128xf32>) - %ini3 = "some_def"() : () -> (vector<128xf32>) - %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) { - %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>) - %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>) - %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>) - scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32> - } - gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> - } - "some_use_1"(%0#0) : (vector<4xf32>) -> () - "some_use_2"(%0#1) : (vector<4xf32>) -> 
() - "some_use_3"(%0#2) : (vector<8xf32>) -> () - return -} - // ----- // CHECK-PROP-LABEL: func @vector_reduction(