diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index a2f03f1e1056e..a1df366cef132 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -40,12 +40,15 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp, SmallVector ivs = llvm::map_to_vector( loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); }); + SmallVector replacementVals = ivs; + for (Value shared : forallOp.getOutputs()) + replacementVals.push_back(shared); Block *innermostBlock = loopNest.loops.back().getBody(); rewriter.eraseOp(forallOp.getBody()->getTerminator()); rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock, innermostBlock->getTerminator()->getIterator(), - ivs); - rewriter.eraseOp(forallOp); + replacementVals); + rewriter.replaceOp(forallOp, forallOp.getOutputs()); if (results) { llvm::move(loopNest.loops, std::back_inserter(*results)); diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir index e7d183fb9d2b5..17598a154fefd 100644 --- a/mlir/test/Dialect/SCF/forall-to-for.mlir +++ b/mlir/test/Dialect/SCF/forall-to-for.mlir @@ -55,3 +55,26 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) { } return } + +// ----- + + func.func @parallel_insert_slice(%arg0: tensor<100xf32>) -> tensor<100xf32> { + %c100 = arith.constant 100 : index + %res = scf.forall (%i) in (%c100) shared_outs(%s = %arg0) -> (tensor<100xf32>) { + %t = "test.foo"() : () -> tensor<100xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %t into %s[%i] [100] [1] : tensor<100xf32> into tensor<100xf32> + } + } + return %res : tensor<100xf32> + } +// CHECK-LABEL: func.func @parallel_insert_slice( +// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<100xf32>) -> tensor<100xf32> { +// CHECK: %[[VAL_1:.*]] = arith.constant 100 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] { +// CHECK: %[[VAL_5:.*]] = "test.foo"() : () -> tensor<100xf32> +// CHECK: } +// CHECK: return %[[VAL_0]] : tensor<100xf32> +// CHECK: } \ No newline at end of file