Skip to content

[mlir][scf] Implement Conversion from scf.parallel to Nested scf.for #147692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

mmarjieh
Copy link
Contributor

@mmarjieh mmarjieh commented Jul 9, 2025

Add a utility function/transform operation to convert scf.parallel loops to nested scf.for loops.

Add a utility function/transform operation to convert `scf.parallel`
loops to nested `scf.for` loops.
@llvmbot
Copy link
Member

llvmbot commented Jul 9, 2025

@llvm/pr-subscribers-mlir-scf

Author: Michael Marjieh (mmarjieh)

Changes

Add a utility function/transform operation to convert scf.parallel loops to nested scf.for loops.


Full diff: https://github.com/llvm/llvm-project/pull/147692.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+28)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+11)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+6)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+38)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp (+91)
  • (added) mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir (+80)
  • (added) mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir (+62)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5dba8c5e57ba8..e2b42208f3f8e 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.parallel into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.parallel` operation pointed to by the given handle into a
+    set of nested `scf.for` operations. Each new operation corresponds to one
+    dimension of the original parallel loop.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with shared outputs are currently not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.parallel` payload operation.
+    Returns as many handles as the given `parallel` op has dimensions that are
+    associated with the generated `scf.for` loops.
+    Produces a silenceable failure if another number of resulting handles is
+    requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index b70599df6f503..54b0118507184 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
 /// Creates a pass that converts SCF forall loops to SCF parallel loops.
 std::unique_ptr<Pass> createForallToParallelLoopPass();
 
+/// Creates a pass that converts SCF forall loops to SCF parallel loops.
+std::unique_ptr<Pass> createParallelForToNestedForsPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..afa4ef460c219 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
   let constructor = "mlir::createForallToParallelLoopPass()";
 }
 
+def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
+  let summary = "Convert SCF parallel for loops to nested SCF for loops";
+  let constructor = "mlir::createParallelForToNestedForsPass()";
+  let description = [{
+    This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp
+    operations. The transformation is useful for cases where the parallel loop
+    can be expressed as a series of sequential iterations, allowing for more
+    fine-grained control over the loop execution.
+  }];
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 63163b77f7f16..5e613238d016d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -42,6 +42,12 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
 LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
                                    ParallelOp *result = nullptr);
 
+/// Try converting scf.forall into an scf.parallel loop.
+/// The conversion is only supported for forall operations with no results.
+LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
+                                      ParallelOp parallelOp,
+                                      ForOp *result = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 57c27231f2144..7fd9255c490ef 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -149,6 +149,44 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelForToNestedForOps
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.parallel";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ForOp opResult;
+  if (failed(scf::parallelForToNestedFors(rewriter, target, &opResult))) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to convert parallel into nested fors";
+    return diag;
+  }
+
+  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..a9ffa9dc208a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
+  ParallelForToNestedFors.cpp
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
new file mode 100644
index 0000000000000..75672f1c9239e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -0,0 +1,91 @@
+//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ParallelOp to nested scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
+                                                 scf::ParallelOp parallelOp,
+                                                 scf::ForOp *result) {
+
+  if (!parallelOp.getResults().empty()) {
+    parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
+                          "doesn't support ScfParallel with results.");
+    return failure();
+  }
+
+  rewriter.setInsertionPoint(parallelOp);
+
+  Location loc = parallelOp.getLoc();
+  auto lowerBounds = parallelOp.getLowerBound();
+  auto upperBounds = parallelOp.getUpperBound();
+  auto steps = parallelOp.getStep();
+
+  assert(lowerBounds.size() == upperBounds.size() &&
+         lowerBounds.size() == steps.size() &&
+         "Mismatched parallel loop bounds");
+
+  SmallVector<Value> ivs;
+  auto loopNest =
+      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
+
+  auto oldInductionVars = parallelOp.getInductionVars();
+  auto newInductionVars = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
+  assert(oldInductionVars.size() == newInductionVars.size() &&
+         "Mismatched induction variables");
+  for (auto [oldIV, newIV] : llvm::zip(oldInductionVars, newInductionVars))
+    oldIV.replaceAllUsesWith(newIV);
+
+  auto *linearizedBody = loopNest.loops.back().getBody();
+  Block &parallelBody = *parallelOp.getBody();
+  for (Operation &op : llvm::make_early_inc_range(parallelBody)) {
+    // Skip the terminator of the parallelOp body.
+    if (&op == parallelBody.getTerminator())
+      continue;
+    op.moveBefore(linearizedBody->getTerminator());
+  }
+  rewriter.eraseOp(parallelOp);
+  if (result)
+    *result = loopNest.loops.front();
+  return success();
+}
+
+namespace {
+struct ParallelForToNestedFors final
+    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    parentOp->walk([&](scf::ParallelOp parallelOp) {
+      if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
+        return signalPassFailure();
+      }
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
+  return std::make_unique<ParallelForToNestedFors>();
+}
diff --git a/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..4df7bab790ea5
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  // CHECK:           scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
+  // CHECK:             scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
+  // CHECK:               func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
+  // CHECK:               scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
+  // CHECK:                 scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
+  // CHECK:                   func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
+  // CHECK:                 }
+  // CHECK:               }
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+func.func private @callee(%i: index, %j: index) -> i32
+
+func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
+  %0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
+    %curr = func.call @callee(%i, %j) : (index, index) -> i32
+    scf.reduce(%curr : i32) {
+      ^bb0(%arg3: i32, %arg4: i32):
+        %3 = arith.addi %arg3, %arg4 : i32
+        scf.reduce.return %3 : i32
+    }
+  }
+  return %0 : i32
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..496123b288038
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.parallel}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 9, 2025

@llvm/pr-subscribers-mlir

Author: Michael Marjieh (mmarjieh)

Changes

Add a utility function/transform operation to convert scf.parallel loops to nested scf.for loops.


Full diff: https://github.com/llvm/llvm-project/pull/147692.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+28)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+11)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h (+6)
  • (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+38)
  • (modified) mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp (+91)
  • (added) mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir (+80)
  • (added) mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir (+62)
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5dba8c5e57ba8..e2b42208f3f8e 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
 }
 
+def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>]> {
+  let summary = "Converts scf.parallel into a nest of scf.for operations";
+  let description = [{
+    Converts the `scf.parallel` operation pointed to by the given handle into a
+    set of nested `scf.for` operations. Each new operation corresponds to one
+    dimension of the original parallel loop.
+
+    The operand handle must be associated with exactly one payload operation.
+
+    Loops with shared outputs are currently not supported.
+
+    #### Return Modes
+
+    Consumes the operand handle. Produces a silenceable failure if the operand
+    is not associated with a single `scf.parallel` payload operation.
+    Returns as many handles as the given `parallel` op has dimensions that are
+    associated with the generated `scf.for` loops.
+    Produces a silenceable failure if another number of resulting handles is
+    requested.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index b70599df6f503..54b0118507184 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
 /// Creates a pass that converts SCF forall loops to SCF parallel loops.
 std::unique_ptr<Pass> createForallToParallelLoopPass();
 
+/// Creates a pass that converts SCF forall loops to SCF parallel loops.
+std::unique_ptr<Pass> createParallelForToNestedForsPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 6e5ef96c450aa..afa4ef460c219 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
   let constructor = "mlir::createForallToParallelLoopPass()";
 }
 
+def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
+  let summary = "Convert SCF parallel for loops to nested SCF for loops";
+  let constructor = "mlir::createParallelForToNestedForsPass()";
+  let description = [{
+    This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp
+    operations. The transformation is useful for cases where the parallel loop
+    can be expressed as a series of sequential iterations, allowing for more
+    fine-grained control over the loop execution.
+  }];
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 63163b77f7f16..5e613238d016d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -42,6 +42,12 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
 LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
                                    ParallelOp *result = nullptr);
 
+/// Try converting scf.forall into an scf.parallel loop.
+/// The conversion is only supported for forall operations with no results.
+LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
+                                      ParallelOp parallelOp,
+                                      ForOp *result = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 57c27231f2144..7fd9255c490ef 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -149,6 +149,44 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelForToNestedForOps
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto payload = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(payload))
+    return emitSilenceableError() << "expected a single payload op";
+
+  auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
+  if (!target) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected the payload to be scf.parallel";
+    diag.attachNote((*payload.begin())->getLoc()) << "payload op";
+    return diag;
+  }
+
+  if (getNumResults() != 1) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "op expects one result, given "
+                                       << getNumResults();
+    diag.attachNote(target.getLoc()) << "payload op";
+    return diag;
+  }
+
+  scf::ForOp opResult;
+  if (failed(scf::parallelForToNestedFors(rewriter, target, &opResult))) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "failed to convert parallel into nested fors";
+    return diag;
+  }
+
+  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopOutlineOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 84dd992bec53a..a9ffa9dc208a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
+  ParallelForToNestedFors.cpp
   ParallelLoopCollapsing.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
new file mode 100644
index 0000000000000..75672f1c9239e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
@@ -0,0 +1,91 @@
+//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ParallelOp to nested scf.for ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
+                                                 scf::ParallelOp parallelOp,
+                                                 scf::ForOp *result) {
+
+  if (!parallelOp.getResults().empty()) {
+    parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
+                          "doesn't support ScfParallel with results.");
+    return failure();
+  }
+
+  rewriter.setInsertionPoint(parallelOp);
+
+  Location loc = parallelOp.getLoc();
+  auto lowerBounds = parallelOp.getLowerBound();
+  auto upperBounds = parallelOp.getUpperBound();
+  auto steps = parallelOp.getStep();
+
+  assert(lowerBounds.size() == upperBounds.size() &&
+         lowerBounds.size() == steps.size() &&
+         "Mismatched parallel loop bounds");
+
+  SmallVector<Value> ivs;
+  auto loopNest =
+      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
+
+  auto oldInductionVars = parallelOp.getInductionVars();
+  auto newInductionVars = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
+  assert(oldInductionVars.size() == newInductionVars.size() &&
+         "Mismatched induction variables");
+  for (auto [oldIV, newIV] : llvm::zip(oldInductionVars, newInductionVars))
+    oldIV.replaceAllUsesWith(newIV);
+
+  auto *linearizedBody = loopNest.loops.back().getBody();
+  Block &parallelBody = *parallelOp.getBody();
+  for (Operation &op : llvm::make_early_inc_range(parallelBody)) {
+    // Skip the terminator of the parallelOp body.
+    if (&op == parallelBody.getTerminator())
+      continue;
+    op.moveBefore(linearizedBody->getTerminator());
+  }
+  rewriter.eraseOp(parallelOp);
+  if (result)
+    *result = loopNest.loops.front();
+  return success();
+}
+
+namespace {
+struct ParallelForToNestedFors final
+    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    parentOp->walk([&](scf::ParallelOp parallelOp) {
+      if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
+        return signalPassFailure();
+      }
+    });
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
+  return std::make_unique<ParallelForToNestedFors>();
+}
diff --git a/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..4df7bab790ea5
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  // CHECK:           scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
+  // CHECK:             scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
+  // CHECK:               func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
+  // CHECK:               scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
+  // CHECK:                 scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
+  // CHECK:                   func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
+  // CHECK:                 }
+  // CHECK:               }
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+// -----
+func.func private @callee(%i: index, %j: index) -> i32
+
+func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
+  %0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
+    %curr = func.call @callee(%i, %j) : (index, index) -> i32
+    scf.reduce(%curr : i32) {
+      ^bb0(%arg3: i32, %arg4: i32):
+        %3 = arith.addi %arg3, %arg4 : i32
+        scf.reduce.return %3 : i32
+    }
+  }
+  return %0 : i32
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
new file mode 100644
index 0000000000000..496123b288038
--- /dev/null
+++ b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK:           scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
+  // CHECK:             scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
+  // CHECK:               func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
+  // CHECK:             }
+  // CHECK:           }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected a single payload op}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// expected-note @below {{payload op}}
+func.func private @callee(%i: index, %j: index)
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{expected the payload to be scf.parallel}}
+    transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

@MaheshRavishankar MaheshRavishankar requested a review from Max191 July 9, 2025 15:50
@MaheshRavishankar
Copy link
Contributor

cc @Max191 who implemented something similar in IREE. Maybe we can swap the downstream implementation to the upstream one.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants