Skip to content

[mlir][vector] Refactor WarpOpScfForOp to support unused or swapped forOp results. #147620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

charithaintc
Copy link
Contributor

Current implementation generates incorrect code or crashes in the following valid cases.

  1. At least one of the for op results are not yielded by the warpOp.
    Example:
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
    ....
    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
      
      %1  = ...
      %acc = ....
      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
    }
    gpu.yield %3#0 : vector<128xf32> // %3#1 is not used but can not be removed as dead code (loop carried).
  }
  "some_use"(%0) : (vector<4xf32>) -> ()
  return
  1. Enclosing warpOp yields the forOp results in different order compared to the forOp results.
    Example:
  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
    ....
    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
      .....
      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
    }
    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32> // swapped order
  }
  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
  "some_use_3"(%0#2) : (vector<8xf32>) -> ()

@llvmbot
Copy link
Member

llvmbot commented Jul 9, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

Current implementation generates incorrect code or crashes in the following valid cases.

  1. At least one of the for op results are not yielded by the warpOp.
    Example:
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -&gt; (vector&lt;4xf32&gt;) {
    ....
    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -&gt; (vector&lt;128xf32&gt;, vector&lt;128xf32&gt;) {
      
      %1  = ...
      %acc = ....
      scf.yield %acc, %1 : vector&lt;128xf32&gt;, vector&lt;128xf32&gt;
    }
    gpu.yield %3#<!-- -->0 : vector&lt;128xf32&gt; // %3#<!-- -->1 is not used but can not be removed as dead code (loop carried).
  }
  "some_use"(%0) : (vector&lt;4xf32&gt;) -&gt; ()
  return
  1. Enclosing warpOp yields the forOp results in different order compared to the forOp results.
    Example:
  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -&gt; (vector&lt;4xf32&gt;, vector&lt;4xf32&gt;, vector&lt;8xf32&gt;) {
    ....
    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -&gt; (vector&lt;256xf32&gt;, vector&lt;128xf32&gt;, vector&lt;128xf32&gt;) {
      .....
      scf.yield %acc1, %acc2, %acc3 : vector&lt;256xf32&gt;, vector&lt;128xf32&gt;, vector&lt;128xf32&gt;
    }
    gpu.yield %3#<!-- -->2, %3#<!-- -->1, %3#<!-- -->0 : vector&lt;128xf32&gt;, vector&lt;128xf32&gt;, vector&lt;256xf32&gt; // swapped order
  }
  "some_use_1"(%0#<!-- -->0) : (vector&lt;4xf32&gt;) -&gt; ()
  "some_use_2"(%0#<!-- -->1) : (vector&lt;4xf32&gt;) -&gt; ()
  "some_use_3"(%0#<!-- -->2) : (vector&lt;8xf32&gt;) -&gt; ()


Full diff: https://github.com/llvm/llvm-project/pull/147620.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+124-47)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+79)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index fb99e22c77ea0..3ce134fe5f3ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -17,8 +17,12 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1745,19 +1749,18 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
       : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    auto yield = cast<gpu::YieldOp>(
+    auto newWarpOpYield = cast<gpu::YieldOp>(
         warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
     // Only pick up forOp if it is the last op in the region.
-    Operation *lastNode = yield->getPrevNode();
+    Operation *lastNode = newWarpOpYield->getPrevNode();
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
     // Collect Values that come from the warp op but are outside the forOp.
-    // Those Value needs to be returned by the original warpOp and passed to
-    // the new op.
+    // Those Value needs to be returned by the new warp op.
     llvm::SmallSetVector<Value, 32> escapingValues;
-    SmallVector<Type> inputTypes;
-    SmallVector<Type> distTypes;
+    SmallVector<Type> escapingValueInputTypes;
+    SmallVector<Type> escapingValuedistTypes;
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
@@ -1769,81 +1772,155 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
               AffineMap map = distributionMapFn(operand->get());
               distType = getDistributedType(vecType, map, warpOp.getWarpSize());
             }
-            inputTypes.push_back(operand->get().getType());
-            distTypes.push_back(distType);
+            escapingValueInputTypes.push_back(operand->get().getType());
+            escapingValuedistTypes.push_back(distType);
           }
         });
 
-    if (llvm::is_contained(distTypes, Type{}))
+    if (llvm::is_contained(escapingValuedistTypes, Type{}))
       return failure();
+    // Warp op can yield two types of values:
+    // 1. Values that are not results of the forOp:
+    //    These values must also be yielded by the new warp op. Also, we need to
+    //    record the index mapping for these values to replace them later.
+    // 2. Values that are results of the forOp:
+    //    In this case, we record the index mapping between the warp op result
+    //    index and matching forOp result index.
+    SmallVector<Value> nonForYieldedValues;
+    SmallVector<unsigned> nonForResultIndices;
+    DenseMap<unsigned, unsigned> forResultMapping;
+    for (OpOperand &yieldOperand : newWarpOpYield->getOpOperands()) {
+      // Yielded value is not a result of the forOp.
+      if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) {
+        nonForYieldedValues.push_back(yieldOperand.get());
+        nonForResultIndices.push_back(yieldOperand.getOperandNumber());
+        continue;
+      }
+      OpResult forResult = cast<OpResult>(yieldOperand.get());
+      forResultMapping[yieldOperand.getOperandNumber()] =
+          forResult.getResultNumber();
+    }
 
-    SmallVector<size_t> newRetIndices;
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
-        newRetIndices);
-    yield = cast<gpu::YieldOp>(
+    // Newly created warp op will yield values in following order:
+    // 1. All init args of the forOp.
+    // 2. All escaping values.
+    // 3. All non-for yielded values.
+    SmallVector<Value> newWarpOpYieldValues;
+    SmallVector<Type> newWarpOpDistTypes;
+    for (auto [i, initArg] : llvm::enumerate(forOp.getInitArgs())) {
+      newWarpOpYieldValues.push_back(initArg);
+      // Compute the distributed type for this init arg.
+      Type distType = initArg.getType();
+      if (auto vecType = dyn_cast<VectorType>(distType)) {
+        AffineMap map = distributionMapFn(initArg);
+        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+      }
+      newWarpOpDistTypes.push_back(distType);
+    }
+    // Insert escaping values and their distributed types.
+    newWarpOpYieldValues.insert(newWarpOpYieldValues.end(),
+                                escapingValues.begin(), escapingValues.end());
+    newWarpOpDistTypes.insert(newWarpOpDistTypes.end(),
+                              escapingValuedistTypes.begin(),
+                              escapingValuedistTypes.end());
+    // Next, we insert all non-for yielded values and their distributed types.
+    // We also create a mapping between the non-for yielded value index and the
+    // corresponding new warp op yield value index (needed to update users
+    // later).
+    DenseMap<unsigned, unsigned> warpResultMapping;
+    for (auto [i, v] : llvm::enumerate(nonForYieldedValues)) {
+      warpResultMapping[nonForResultIndices[i]] = newWarpOpYieldValues.size();
+      newWarpOpYieldValues.push_back(v);
+      newWarpOpDistTypes.push_back(
+          warpOp.getResult(nonForResultIndices[i]).getType());
+    }
+    // Create the new warp op with the updated yield values and types.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    newWarpOpYield = cast<gpu::YieldOp>(
         newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
 
-    SmallVector<Value> newOperands;
-    SmallVector<unsigned> resultIdx;
-    // Collect all the outputs coming from the forOp.
-    for (OpOperand &yieldOperand : yield->getOpOperands()) {
-      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
-        continue;
-      auto forResult = cast<OpResult>(yieldOperand.get());
-      newOperands.push_back(
-          newWarpOp.getResult(yieldOperand.getOperandNumber()));
-      yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]);
-      resultIdx.push_back(yieldOperand.getOperandNumber());
-    }
+    // Next, we create a new for op with the init args yielded by the new
+    // warp op.
+    unsigned escapingValuesStartIdx =
+        forOp.getInitArgs().size(); // ForOp init args are positioned before
+                                    // escaping values in the new warp op.
+    SmallVector<Value> newForOpOperands;
+    for (size_t i = 0; i < escapingValuesStartIdx; ++i)
+      newForOpOperands.push_back(newWarpOp.getResult(i));
 
+    // Create a new for op outside the new warp op region.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
-
-    // Create a new for op outside the region with a WarpExecuteOnLane0Op
-    // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
-        forOp.getStep(), newOperands);
+        forOp.getStep(), newForOpOperands);
+    // Next, we insert a new warp op (called inner warp op) inside the
+    // newly created for op. This warp op will contain all ops that were
+    // contained within the original for op body.
     rewriter.setInsertionPointToStart(newForOp.getBody());
 
-    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
-                                 newForOp.getRegionIterArgs().end());
-    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
-                                    forOp.getResultTypes().end());
+    SmallVector<Value> innerWarpInput(newForOp.getRegionIterArgs().begin(),
+                                      newForOp.getRegionIterArgs().end());
+    SmallVector<Type> innerWarpInputType(forOp.getResultTypes().begin(),
+                                         forOp.getResultTypes().end());
+    // Escaping values are forwarded to the inner warp op as its (additional)
+    // arguments. We keep track of the mapping between these values and their
+    // argument index in the inner warp op (to replcace uses later).
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
-    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      warpInput.push_back(newWarpOp.getResult(retIdx));
-      argIndexMapping[escapingValues[i]] = warpInputType.size();
-      warpInputType.push_back(inputTypes[i]);
+    for (size_t i = escapingValuesStartIdx;
+         i < escapingValuesStartIdx + escapingValues.size(); ++i) {
+      innerWarpInput.push_back(newWarpOp.getResult(i));
+      argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
+          innerWarpInputType.size();
+      innerWarpInputType.push_back(
+          escapingValueInputTypes[i - escapingValuesStartIdx]);
     }
+    // Create the inner warp op with the new input values and types.
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
         newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
-        newWarpOp.getWarpSize(), warpInput, warpInputType);
+        newWarpOp.getWarpSize(), innerWarpInput, innerWarpInputType);
 
+    // Inline the for op body into the inner warp op body.
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
-    for (Value args : innerWarp.getBody()->getArguments()) {
+    for (Value args : innerWarp.getBody()->getArguments())
       argMapping.push_back(args);
-    }
+
     argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
+
     rewriter.eraseOp(forOp.getBody()->getTerminator());
     rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
+
+    // Insert a gpu yieldOp at the end of the inner warp op body that yields
+    // original forOp results.
     rewriter.setInsertionPointToEnd(innerWarp.getBody());
     rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands);
     rewriter.setInsertionPointAfter(innerWarp);
+    // Insert a scf.yield op at the end of the new for op body that yields
+    // the inner warp op results.
     if (!innerWarp.getResults().empty())
       rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
+
+    // Update the users of original warp op results that were coming from the
+    // original forOp to the corresponding new forOp result.
+    for (auto [origIdx, newIdx] : forResultMapping)
+      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+                                    newForOp.getResult(newIdx), newForOp);
+    // Similarly, update any users of the warp op results that were not
+    // results of the forOp.
+    for (auto [origIdx, newIdx] : warpResultMapping)
+      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
+                                  newWarpOp.getResult(newIdx));
+    // Remove the original warp op and for op, they should not have any uses
+    // at this point.
     rewriter.eraseOp(forOp);
-    // Replace the warpOp result coming from the original ForOp.
-    for (const auto &res : llvm::enumerate(resultIdx)) {
-      rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()),
-                                  newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
-    }
+    rewriter.eraseOp(warpOp);
+    // Update any users of escaping values that were forwarded to the
+    // inner warp op. These values are now arguments of the inner warp op.
     newForOp.walk([&](Operation *op) {
       for (OpOperand &operand : op->getOpOperands()) {
         auto it = argIndexMapping.find(operand.get());
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 1161dbd4b2166..56fa38ce5a3e8 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,6 +584,85 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_for_result(
+//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//       CHECK-PROP:  gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:  %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
+//       CHECK-PROP:    %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
+//       CHECK-PROP:    %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+//       CHECK-PROP:    gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
+//       CHECK-PROP:  }
+//       CHECK-PROP:  scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
+//       CHECK-PROP: }
+//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_unused_for_result(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %ini1 = "some_def"() : () -> (vector<128xf32>)
+    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
+      %add = arith.addi %arg3, %c1 : index
+      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
+      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#0 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+// CHECK-PROP-LABEL: func.func @warp_scf_for_swapped_for_results(
+//       CHECK-PROP:  %[[W0:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[INI0:.*]] = "some_def"() : () -> vector<256xf32>
+//  CHECK-PROP-NEXT:    %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    %[[INI2:.*]] = "some_def"() : () -> vector<128xf32>
+//  CHECK-PROP-NEXT:    gpu.yield %[[INI0]], %[[INI1]], %[[INI2]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  %[[F0:.*]]:3 = scf.for {{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1, %{{.*}} = %[[W0]]#2) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:    %[[W1:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} :
+//  CHECK-PROP-SAME:        vector<8xf32>, vector<4xf32>, vector<4xf32>) -> (vector<8xf32>, vector<4xf32>, vector<4xf32>) {
+//  CHECK-PROP-NEXT:      ^bb0(%{{.*}}: vector<256xf32>, %{{.*}}: vector<128xf32>, %{{.*}}: vector<128xf32>):
+//  CHECK-PROP-NEXT:        %[[T3:.*]] = "some_def_1"(%{{.*}}) : (vector<256xf32>) -> vector<256xf32>
+//  CHECK-PROP-NEXT:        %[[T4:.*]] = "some_def_2"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        %[[T5:.*]] = "some_def_3"(%{{.*}}) : (vector<128xf32>) -> vector<128xf32>
+//  CHECK-PROP-NEXT:        gpu.yield %[[T3]], %[[T4]], %[[T5]] : vector<256xf32>, vector<128xf32>, vector<128xf32>
+//  CHECK-PROP-NEXT:    }
+//  CHECK-PROP-NEXT:    scf.yield %[[W1]]#0, %[[W1]]#1, %[[W1]]#2 : vector<8xf32>, vector<4xf32>, vector<4xf32>
+//  CHECK-PROP-NEXT:  }
+//  CHECK-PROP-NEXT:  "some_use_1"(%[[F0]]#2) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_2"(%[[F0]]#1) : (vector<4xf32>) -> ()
+//  CHECK-PROP-NEXT:  "some_use_3"(%[[F0]]#0) : (vector<8xf32>) -> ()
+func.func @warp_scf_for_swapped_for_results(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0:3 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>, vector<4xf32>, vector<8xf32>) {
+    %ini1 = "some_def"() : () -> (vector<256xf32>)
+    %ini2 = "some_def"() : () -> (vector<128xf32>)
+    %ini3 = "some_def"() : () -> (vector<128xf32>)
+    %3:3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini1, %arg5 = %ini2, %arg6 = %ini3) -> (vector<256xf32>, vector<128xf32>, vector<128xf32>) {
+      %acc1 = "some_def_1"(%arg4) : (vector<256xf32>) -> (vector<256xf32>)
+      %acc2 = "some_def_2"(%arg5) : (vector<128xf32>) -> (vector<128xf32>)
+      %acc3 = "some_def_3"(%arg6) : (vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc1, %acc2, %acc3 : vector<256xf32>, vector<128xf32>, vector<128xf32>
+    }
+    gpu.yield %3#2, %3#1, %3#0 : vector<128xf32>, vector<128xf32>, vector<256xf32>
+  }
+  "some_use_1"(%0#0) : (vector<4xf32>) -> ()
+  "some_use_2"(%0#1) : (vector<4xf32>) -> ()
+  "some_use_3"(%0#2) : (vector<8xf32>) -> ()
+  return
+}
+
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transformation looks correct to me, nice! Wouldn't we need basically the same thing for other scf ops? Seems like this can be generalized.

// for now.
if (!layout)
return AffineMap::getMultiDimMapWithTargets(
vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes 2d, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. At SG level layout is currently only 2d. But upstream distribution makes no such assumption. Only assumption there is only 1 dim is distributed. We need to add more support there in future.

Comment on lines 1789 to 1791
SmallVector<Value> nonForYieldedValues;
SmallVector<unsigned> nonForResultIndices;
DenseMap<unsigned, unsigned> forResultMapping;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably do the mapping with some existing tools like a value to value map.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about using IRMapping. but it has the same amount of book keeping. So I don't see any added benefit. DenseMap does the job and code looks easy to read in my view.

But if you can point to some code example with value mapping I will reconsider it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like function cloning comes to mind, but if it's the same effort/complexity, it doesn't really matter, I suppose.

forOp.getStep(), newOperands);
forOp.getStep(), newForOpOperands);
// Next, we insert a new warp op (called inner warp op) inside the
// newly created for op. This warp op will contain all ops that were
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, the comments would be easier to read if they highlight the op names, e.g.

Suggested change
// newly created for op. This warp op will contain all ops that were
// newly created `ForOp`. This warp op will contain all ops that were

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cleaned up the comments. thanks!

}
// Next, we create a new for op with the init args yielded by the new
// warp op.
unsigned escapingValuesStartIdx =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unsigned escapingValuesStartIdx =
const unsigned escapingValuesStartIdx =

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

@charithaintc
Copy link
Contributor Author

The transformation looks correct to me, nice! Wouldn't we need basically the same thing for other scf ops? Seems like this can be generalized.

Thanks for the review Peter. Currently we don't have support for other scf ops (like scf.if, scf.while) which we may need later. I will think about generalization steps once we have more use cases for these ops. But appreciate your suggestion. This pattern is very complex to begin with, so some kind of abstraction to deal with these moving parts will def help.

SmallVector<Type> inputTypes;
SmallVector<Type> distTypes;
SmallVector<Type> escapingValueInputTypes;
SmallVector<Type> escapingValuedistTypes;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dist->Dist

Copy link
Contributor

@kurapov-peter kurapov-peter left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants