Skip to content

[mlir][scf] Implement Conversion from scf.parallel to Nested scf.for #147692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

// Transform op that rewrites a single `scf.parallel` payload op into a nest of
// `scf.for` ops. The description documents what the C++ `apply` implementation
// does: exactly one result handle is accepted and it is mapped to the
// outermost generated loop.
def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     DeclareOpInterfaceMethods<TransformOpInterface>]> {
  let summary = "Converts scf.parallel into a nest of scf.for operations";
  let description = [{
    Converts the `scf.parallel` operation pointed to by the given handle into a
    nest of `scf.for` operations. Each new operation corresponds to one
    dimension of the original parallel loop.

    The operand handle must be associated with exactly one payload operation.

    Parallel loops that produce results (i.e. loops with reductions) are
    currently not supported.

    #### Return Modes

    Consumes the operand handle. Produces a silenceable failure if the operand
    is not associated with a single `scf.parallel` payload operation, or if the
    conversion of that operation fails.
    Returns a single handle that is associated with the outermost generated
    `scf.for` loop.
    Produces a silenceable failure if another number of resulting handles is
    requested.
  }];
  let arguments = (ins TransformHandleTypeInterface:$target);
  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);

  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
std::unique_ptr<Pass> createForallToParallelLoopPass();

/// Creates a pass that converts SCF parallel loops to nested SCF for loops.
std::unique_ptr<Pass> createParallelForToNestedForsPass();

// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();

Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let constructor = "mlir::createForallToParallelLoopPass()";
}

// Pass definition for `-scf-parallel-for-to-nested-fors`; instances are
// created through mlir::createParallelForToNestedForsPass().
def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
let summary = "Convert SCF parallel for loops to nested SCF for loops";
let constructor = "mlir::createParallelForToNestedForsPass()";
let description = [{
This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp
operations. The transformation is useful for cases where the parallel loop
can be expressed as a series of sequential iterations, allowing for more
fine-grained control over the loop execution.
}];
}

def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
ParallelOp *result = nullptr);

/// Try converting scf.parallel into a nest of scf.for loops, one loop per
/// dimension of the parallel op.
/// The conversion is only supported for parallel operations with no results;
/// otherwise an error is emitted on the op and failure is returned.
/// On success the parallel op is erased and, if `result` is non-null, it is
/// set to the outermost loop of the generated nest.
LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
ParallelOp parallelOp,
ForOp *result = nullptr);

/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,44 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ParallelForToNestedForOps
//===----------------------------------------------------------------------===//

/// Applies the transform to the payload behind the target handle.
///
/// Preconditions, each reported as a silenceable error when violated:
///   * the target handle maps to exactly one payload op,
///   * that payload op is an `scf.parallel`,
///   * exactly one result handle is requested.
/// The payload is then rewritten with scf::parallelForToNestedFors; a failed
/// rewrite is also reported as a silenceable error. On success the single
/// result handle is associated with the loop returned by that helper.
DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  auto *payloadOp = *payloadOps.begin();
  auto parallelOp = dyn_cast<scf::ParallelOp>(payloadOp);
  if (!parallelOp) {
    DiagnosedSilenceableFailure notParallel =
        emitSilenceableError() << "expected the payload to be scf.parallel";
    // Point the user at the op that was matched instead.
    notParallel.attachNote(payloadOp->getLoc()) << "payload op";
    return notParallel;
  }

  if (getNumResults() != 1) {
    DiagnosedSilenceableFailure wrongArity =
        emitSilenceableError()
        << "op expects one result, given " << getNumResults();
    wrongArity.attachNote(parallelOp.getLoc()) << "payload op";
    return wrongArity;
  }

  scf::ForOp outerLoop;
  if (failed(scf::parallelForToNestedFors(rewriter, parallelOp, &outerLoop)))
    return emitSilenceableError()
           << "failed to convert parallel into nested fors";

  results.set(cast<OpResult>(getTransformed()[0]), {outerLoop});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
LoopPipelining.cpp
LoopRangeFolding.cpp
LoopSpecialization.cpp
ParallelForToNestedFors.cpp
ParallelLoopCollapsing.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
Expand Down
91 changes: 91 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/ParallelForToNestedFors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ParallelOp to nested scf.for ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Rewrites `parallelOp` into a nest of `scf.for` loops, one loop per
/// dimension of the parallel op, and moves the parallel body into the
/// innermost loop.
///
/// \param rewriter   Rewriter used for every IR mutation so that attached
///                   listeners observe the change. Its insertion point is
///                   restored before returning.
/// \param parallelOp The loop to convert. It is erased on success.
/// \param result     Optional out-parameter; when non-null it receives the
///                   outermost loop of the generated nest.
/// \returns success, or failure (after emitting an error on `parallelOp`) when
///          the parallel op has results. The IR is not modified on failure.
LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
                                                 scf::ParallelOp parallelOp,
                                                 scf::ForOp *result) {
  if (!parallelOp.getResults().empty()) {
    parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
                          "doesn't support ScfParallel with results.");
    return failure();
  }

  // The insertion point is set at `parallelOp`, which is erased below. Restore
  // the caller's insertion point on exit so the rewriter is not left pointing
  // at a deleted operation.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(parallelOp);

  Location loc = parallelOp.getLoc();
  auto lowerBounds = parallelOp.getLowerBound();
  auto upperBounds = parallelOp.getUpperBound();
  auto steps = parallelOp.getStep();

  assert(lowerBounds.size() == upperBounds.size() &&
         lowerBounds.size() == steps.size() &&
         "Mismatched parallel loop bounds");

  auto loopNest =
      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);

  auto newInductionVars = llvm::map_to_vector(
      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
  assert(parallelOp.getInductionVars().size() == newInductionVars.size() &&
         "Mismatched induction variables");

  // Drop the terminator of the parallel body, then splice the remaining ops in
  // front of the innermost loop's terminator. Inlining through the rewriter
  // also replaces the old induction variables (the block arguments) with the
  // new ones and notifies listeners, unlike mutating the IR directly.
  Block *parallelBody = parallelOp.getBody();
  Block *innermostBody = loopNest.loops.back().getBody();
  rewriter.eraseOp(parallelBody->getTerminator());
  rewriter.inlineBlockBefore(parallelBody, innermostBody->getTerminator(),
                             newInductionVars);

  rewriter.eraseOp(parallelOp);
  if (result)
    *result = loopNest.loops.front();
  return success();
}

namespace {
/// Pass driver: visits every `scf.parallel` nested under the operation the
/// pass runs on and rewrites it with scf::parallelForToNestedFors. A failed
/// rewrite marks the pass as failed; the walk itself is not interrupted, so
/// the remaining parallel ops are still visited.
struct ParallelForToNestedFors final
    : public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
  void runOnOperation() override {
    Operation *root = getOperation();
    IRRewriter rewriter(root->getContext());

    root->walk([&](scf::ParallelOp loop) {
      if (succeeded(scf::parallelForToNestedFors(rewriter, loop)))
        return;
      signalPassFailure();
    });
  }
};
} // namespace

/// Returns a new instance of the pass that converts scf.parallel ops into
/// nested scf.for ops.
std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
return std::make_unique<ParallelForToNestedFors>();
}
80 changes: 80 additions & 0 deletions mlir/test/Dialect/SCF/parallel-to-nested-fors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s

// External callee; gives the loop body a use of both induction variables.
func.func private @callee(%i: index, %j: index)

// A single two-dimensional scf.parallel without results. The check lines
// require two nested scf.for loops whose induction variables feed the call.
func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
// CHECK: }
// CHECK: }
return
}

// -----

// External callee; gives the loop bodies a use of both induction variables.
func.func private @callee(%i: index, %j: index)

// Two sibling scf.parallel ops over the same bounds. The check lines require
// each of them to be converted into its own two-deep scf.for nest.
func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
// CHECK: }
// CHECK: }
// CHECK: scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
// CHECK: scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
// CHECK: func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
// CHECK: }
// CHECK: }

return
}

// -----

// External callee taking all four induction variables of the nested loops.
func.func private @callee(%i: index, %j: index, %k: index, %l: index)

// A two-dimensional scf.parallel nested inside another one. The check lines
// require both to be converted, giving a four-deep scf.for nest.
func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
}
}
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
// CHECK: scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
// CHECK: scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
return
}

// -----
// External callee whose result is fed into the reduction below.
func.func private @callee(%i: index, %j: index) -> i32

// Negative test: an scf.parallel that yields a result through a reduction is
// not supported, so the conversion must report the diagnostic annotated below.
func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
%c0 = arith.constant 0 : i32
// expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
%0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
%curr = func.call @callee(%i, %j) : (index, index) -> i32
scf.reduce(%curr : i32) {
^bb0(%arg3: i32, %arg4: i32):
%3 = arith.addi %arg3, %arg4 : i32
scf.reduce.return %3 : i32
}
}
return %0 : i32
}
62 changes: 62 additions & 0 deletions mlir/test/Dialect/SCF/transform-op-parallel-to-nested-fors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s

// External callee; gives the loop body a use of both induction variables.
func.func private @callee(%i: index, %j: index)

// Payload for the transform script in this split: one two-dimensional
// scf.parallel that must be rewritten into two nested scf.for loops.
func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
// CHECK: }
// CHECK: }
return
}

// Transform script: matches the scf.parallel ops in the payload and converts
// the matched loop, requesting a single result handle.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// -----

// External callee; gives the loop bodies a use of both induction variables.
func.func private @callee(%i: index, %j: index)

// Payload with two scf.parallel ops, so the match in the transform script of
// this split yields a handle with more than one payload op.
func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

return
}

// Negative test: the handle passed to the transform maps to two payload ops,
// which the transform rejects with the diagnostic annotated below.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected a single payload op}}
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// -----

// The only func.func in this split; it is what the transform script matches,
// so the diagnostic's note is attached to this op.
// expected-note @below {{payload op}}
func.func private @callee(%i: index, %j: index)

// Negative test: the handle maps to a func.func instead of an scf.parallel,
// which the transform rejects with the diagnostic annotated below.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected the payload to be scf.parallel}}
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}