From d463fa94d155d25f6f0ef15daa5e397ae8806f8e Mon Sep 17 00:00:00 2001 From: Michael Marjieh Date: Wed, 9 Jul 2025 13:31:50 +0300 Subject: [PATCH] [mlir][scf] Implement Conversion from scf.parallel to Nested scf.for Add a utility function/transform operation to convert `scf.parallel` loops to nested `scf.for` loops. --- .../SCF/TransformOps/SCFTransformOps.td | 28 ++++++ .../mlir/Dialect/SCF/Transforms/Passes.h | 3 + .../mlir/Dialect/SCF/Transforms/Passes.td | 11 +++ .../mlir/Dialect/SCF/Transforms/Transforms.h | 6 ++ .../SCF/TransformOps/SCFTransformOps.cpp | 38 ++++++++ .../lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 + .../Transforms/ParallelForToNestedFors.cpp | 91 +++++++++++++++++++ .../Dialect/SCF/parallel-to-nested-fors.mlir | 80 ++++++++++++++++ .../transform-op-parallel-to-nested-fors.mlir | 62 +++++++++++++ 9 files changed, 320 insertions(+) create mode 100644 mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp create mode 100644 mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir create mode 100644 mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index 5dba8c5e57ba8..e2b42208f3f8e 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -105,6 +105,34 @@ def ForallToParallelOp : Op]> { + let summary = "Converts scf.parallel into a nest of scf.for operations"; + let description = [{ + Converts the `scf.parallel` operation pointed to by the given handle into a + set of nested `scf.for` operations. Each new operation corresponds to one + dimension of the original parallel loop. + + The operand handle must be associated with exactly one payload operation. + + Loops with results (i.e. reductions) are currently not supported. + + #### Return Modes + + Consumes the operand handle.
Produces a silenceable failure if the operand + is not associated with a single `scf.parallel` payload operation. + Returns a single handle associated with the outermost of the generated + `scf.for` loops. + Produces a silenceable failure if the number of requested result handles is + not exactly one. + }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs Variadic:$transformed); + + let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; +} + def LoopOutlineOp : Op]> { diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h index b70599df6f503..54b0118507184 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h @@ -62,6 +62,9 @@ std::unique_ptr createForallToForLoopPass(); /// Creates a pass that converts SCF forall loops to SCF parallel loops. std::unique_ptr createForallToParallelLoopPass(); +/// Creates a pass that converts SCF parallel loops to nested SCF for loops. +std::unique_ptr createParallelForToNestedForsPass(); + // Creates a pass which lowers for loops into while loops. std::unique_ptr createForToWhileLoopPass(); diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td index 6e5ef96c450aa..afa4ef460c219 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td @@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> { let constructor = "mlir::createForallToParallelLoopPass()"; } +def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> { + let summary = "Convert SCF parallel for loops to nested SCF for loops"; + let constructor = "mlir::createParallelForToNestedForsPass()"; + let description = [{ + This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp + operations.
The transformation is useful for cases where the parallel loop + can be expressed as a series of sequential iterations, allowing for more + fine-grained control over the loop execution. + }]; +} + def SCFForToWhileLoop : Pass<"scf-for-to-while"> { let summary = "Convert SCF for loops to SCF while loops"; let constructor = "mlir::createForToWhileLoopPass()"; diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h index 63163b77f7f16..5e613238d016d 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -42,6 +42,12 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result = nullptr); +/// Try converting scf.parallel into a nest of scf.for loops. +/// The conversion is only supported for parallel operations with no results. +LogicalResult parallelForToNestedFors(RewriterBase &rewriter, + ParallelOp parallelOp, + ForOp *result = nullptr); + /// Fuses all adjacent scf.parallel operations with identical bounds and step /// into one scf.parallel operations. Uses a naive aliasing and dependency /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 57c27231f2144..7fd9255c490ef 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -149,6 +149,44 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// ParallelForToNestedForOps +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + auto payload = state.getPayloadOps(getTarget()); + if (!llvm::hasSingleElement(payload)) + return emitSilenceableError() << "expected a single payload op"; + + auto target = dyn_cast(*payload.begin()); + if (!target) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "expected the payload to be scf.parallel"; + diag.attachNote((*payload.begin())->getLoc()) << "payload op"; + return diag; + } + + if (getNumResults() != 1) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "op expects one result, given " + << getNumResults(); + diag.attachNote(target.getLoc()) << "payload op"; + return diag; + } + + scf::ForOp opResult; + if (failed(scf::parallelForToNestedFors(rewriter, target, &opResult))) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "failed to convert parallel into nested fors"; + return diag; + } + + results.set(cast(getTransformed()[0]), {opResult}); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // LoopOutlineOp //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index 84dd992bec53a..a9ffa9dc208a0 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms LoopPipelining.cpp LoopRangeFolding.cpp LoopSpecialization.cpp + ParallelForToNestedFors.cpp ParallelLoopCollapsing.cpp ParallelLoopFusion.cpp ParallelLoopTiling.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp new file mode 100644 index 0000000000000..75672f1c9239e --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp @@ -0,0 +1,91 @@ +//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms SCF.ParallelOp to nested scf.for ops. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS +#include "mlir/Dialect/SCF/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter, + scf::ParallelOp parallelOp, + scf::ForOp *result) { + + if (!parallelOp.getResults().empty()) { + parallelOp->emitError("Currently ScfParallel to ScfFor conversion " + "doesn't support ScfParallel with results."); + return failure(); + } + + rewriter.setInsertionPoint(parallelOp); + + Location loc = parallelOp.getLoc(); + auto lowerBounds = parallelOp.getLowerBound(); + auto upperBounds = parallelOp.getUpperBound(); + auto steps = parallelOp.getStep(); + + assert(lowerBounds.size() == upperBounds.size() && + lowerBounds.size() == steps.size() && + "Mismatched parallel loop bounds"); + + SmallVector ivs; + auto loopNest = + scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps); + + auto oldInductionVars = parallelOp.getInductionVars(); + auto newInductionVars = llvm::map_to_vector( + loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); }); + assert(oldInductionVars.size() == newInductionVars.size() && + "Mismatched induction variables"); + for (auto [oldIV, newIV] : llvm::zip(oldInductionVars, newInductionVars)) + oldIV.replaceAllUsesWith(newIV); + + auto *linearizedBody = loopNest.loops.back().getBody(); + Block ¶llelBody = *parallelOp.getBody(); + for (Operation &op : llvm::make_early_inc_range(parallelBody)) { + // Skip the terminator of the parallelOp body. 
+ if (&op == parallelBody.getTerminator()) + continue; + op.moveBefore(linearizedBody->getTerminator()); + } + rewriter.eraseOp(parallelOp); + if (result) + *result = loopNest.loops.front(); + return success(); +} + +namespace { +struct ParallelForToNestedFors final + : public impl::SCFParallelForToNestedForsBase { + void runOnOperation() override { + Operation *parentOp = getOperation(); + IRRewriter rewriter(parentOp->getContext()); + + parentOp->walk([&](scf::ParallelOp parallelOp) { + if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) { + return signalPassFailure(); + } + }); + } +}; +} // namespace + +std::unique_ptr mlir::createParallelForToNestedForsPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir new file mode 100644 index 0000000000000..4df7bab790ea5 --- /dev/null +++ b/mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir @@ -0,0 +1,80 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s + +func.func private @callee(%i: index, %j: index) + +func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) { + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + func.call @callee(%i, %j) : (index, index) -> () + } + // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] { + // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] { + // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> () + // CHECK: } + // CHECK: } + return +} + +// ----- + +func.func private @callee(%i: index, %j: index) + +func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) { + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + func.call @callee(%i, %j) : (index, 
index) -> () + } + + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + func.call @callee(%i, %j) : (index, index) -> () + } + // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] { + // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] { + // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> () + // CHECK: } + // CHECK: } + // CHECK: scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] { + // CHECK: scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] { + // CHECK: func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> () + // CHECK: } + // CHECK: } + + return +} + +// ----- + +func.func private @callee(%i: index, %j: index, %k: index, %l: index) + +func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) { + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) { + func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> () + } + } + // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] { + // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] { + // CHECK: scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] { + // CHECK: scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] { + // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> () + // CHECK: } + // CHECK: } + // CHECK: } + // CHECK: } + return +} + +// ----- +func.func private @callee(%i: index, %j: index) -> i32 + +func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 { + %c0 = arith.constant 0 : i32 + // expected-error@+1 {{Currently 
ScfParallel to ScfFor conversion doesn't support ScfParallel with results}} + %0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 { + %curr = func.call @callee(%i, %j) : (index, index) -> i32 + scf.reduce(%curr : i32) { + ^bb0(%arg3: i32, %arg4: i32): + %3 = arith.addi %arg3, %arg4 : i32 + scf.reduce.return %3 : i32 + } + } + return %0 : i32 +} \ No newline at end of file diff --git a/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir new file mode 100644 index 0000000000000..496123b288038 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s + +func.func private @callee(%i: index, %j: index) + +func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) { + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + func.call @callee(%i, %j) : (index, index) -> () + } + // CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] { + // CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] { + // CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> () + // CHECK: } + // CHECK: } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + +func.func private @callee(%i: index, %j: index) + +func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) { + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, 
%ub2) step (%step1, %step2) { + func.call @callee(%i, %j) : (index, index) -> () + } + + scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) { + func.call @callee(%i, %j) : (index, index) -> () + } + + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{expected a single payload op}} + transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// ----- + +// expected-note @below {{payload op}} +func.func private @callee(%i: index, %j: index) + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{expected the payload to be scf.parallel}} + transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op + transform.yield + } +}