@@ -462,14 +462,15 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
462
462
463
463
OpBuilder moduleBuilder (module .getBody ()->getTerminator ());
464
464
465
- // Get values captured by the async region
466
- llvm::SetVector<mlir::Value> usedAbove;
467
- getUsedValuesDefinedAbove (execute.body (), usedAbove);
468
-
469
- // Collect types of the captured values.
470
- auto usedAboveTypes =
471
- llvm::map_range (usedAbove, [](Value value) { return value.getType (); });
472
- SmallVector<Type, 4 > inputTypes (usedAboveTypes.begin (), usedAboveTypes.end ());
465
+ // Collect all outlined function inputs.
466
+ llvm::SetVector<mlir::Value> functionInputs (execute.dependencies ().begin (),
467
+ execute.dependencies ().end ());
468
+ getUsedValuesDefinedAbove (execute.body (), functionInputs);
469
+
470
+ // Collect types for the outlined function inputs and outputs.
471
+ auto typesRange = llvm::map_range (
472
+ functionInputs, [](Value value) { return value.getType (); });
473
+ SmallVector<Type, 4 > inputTypes (typesRange.begin (), typesRange.end ());
473
474
auto outputTypes = execute.getResultTypes ();
474
475
475
476
auto funcType = moduleBuilder.getFunctionType (inputTypes, outputTypes);
@@ -510,14 +511,19 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
510
511
Block *resume = addSuspensionPoint (coro, coroSave.getResult (0 ),
511
512
entryBlock->getTerminator ());
512
513
513
- // Map from values defined above the execute op to the function arguments.
514
+ // Await on all dependencies before starting to execute the body region.
515
+ builder.setInsertionPointToStart (resume);
516
+ for (size_t i = 0 ; i < execute.dependencies ().size (); ++i)
517
+ builder.create <AwaitOp>(loc, func.getArgument (i));
518
+
519
+ // Map from function inputs defined above the execute op to the function
520
+ // arguments.
514
521
BlockAndValueMapping valueMapping;
515
- valueMapping.map (usedAbove , func.getArguments ());
522
+ valueMapping.map (functionInputs , func.getArguments ());
516
523
517
524
// Clone all operations from the execute operation body into the outlined
518
525
// function body, and replace all `async.yield` operations with a call
519
526
// to async runtime to emplace the result token.
520
- builder.setInsertionPointToStart (resume);
521
527
for (Operation &op : execute.body ().getOps ()) {
522
528
if (isa<async::YieldOp>(op)) {
523
529
builder.create <CallOp>(loc, kEmplaceToken , Type (), coro.asyncToken );
@@ -528,9 +534,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
528
534
529
535
// Replace the original `async.execute` with a call to outlined function.
530
536
OpBuilder callBuilder (execute);
531
- SmallVector<Value, 4 > usedAboveArgs (usedAbove. begin (), usedAbove. end ());
532
- auto callOutlinedFunc = callBuilder.create <CallOp>(
533
- loc, func. getName (), execute. getResultTypes (), usedAboveArgs );
537
+ auto callOutlinedFunc =
538
+ callBuilder.create <CallOp>(loc, func. getName (), execute. getResultTypes (),
539
+ functionInputs. getArrayRef () );
534
540
execute.replaceAllUsesWith (callOutlinedFunc.getResults ());
535
541
execute.erase ();
536
542
@@ -673,13 +679,11 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
673
679
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
674
680
675
681
WalkResult outlineResult = module .walk ([&](ExecuteOp execute) {
676
- // We currently do not support execute operations that take async
677
- // token dependencies, async value arguments or produce async results.
678
- if (!execute.dependencies ().empty () || !execute.operands ().empty () ||
679
- !execute.results ().empty ()) {
680
- execute.emitOpError (
681
- " Can't outline async.execute op with async dependencies, arguments "
682
- " or returned async results" );
682
+ // We currently do not support execute operations that have async value
683
+ // operands or produce async results.
684
+ if (!execute.operands ().empty () || !execute.results ().empty ()) {
685
+ execute.emitOpError (" can't outline async.execute op with async value "
686
+ " operands or returned async results" );
683
687
return WalkResult::interrupt ();
684
688
}
685
689
0 commit comments