Skip to content

Commit 9960e12

Browse files
ezhulenev authored and memfrob committed
[mlir] Implement lowering to LLVM of async.execute ops with token dependencies
Add support for lowering `async.execute` operations with token dependencies.

Example:

```
%dep = ... : !async.token
%token = async.execute[%dep] { ... }
```

Token dependencies are lowered to `async.await` operations inside the outlined coroutine body.

Reviewed By: herhut, mehdi_amini, ftynse

Differential Revision: https://reviews.llvm.org/D89958
1 parent e569adf commit 9960e12

File tree

3 files changed

+85
-26
lines changed

3 files changed

+85
-26
lines changed

mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -462,14 +462,15 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
462462

463463
OpBuilder moduleBuilder(module.getBody()->getTerminator());
464464

465-
// Get values captured by the async region
466-
llvm::SetVector<mlir::Value> usedAbove;
467-
getUsedValuesDefinedAbove(execute.body(), usedAbove);
468-
469-
// Collect types of the captured values.
470-
auto usedAboveTypes =
471-
llvm::map_range(usedAbove, [](Value value) { return value.getType(); });
472-
SmallVector<Type, 4> inputTypes(usedAboveTypes.begin(), usedAboveTypes.end());
465+
// Collect all outlined function inputs.
466+
llvm::SetVector<mlir::Value> functionInputs(execute.dependencies().begin(),
467+
execute.dependencies().end());
468+
getUsedValuesDefinedAbove(execute.body(), functionInputs);
469+
470+
// Collect types for the outlined function inputs and outputs.
471+
auto typesRange = llvm::map_range(
472+
functionInputs, [](Value value) { return value.getType(); });
473+
SmallVector<Type, 4> inputTypes(typesRange.begin(), typesRange.end());
473474
auto outputTypes = execute.getResultTypes();
474475

475476
auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes);
@@ -510,14 +511,19 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
510511
Block *resume = addSuspensionPoint(coro, coroSave.getResult(0),
511512
entryBlock->getTerminator());
512513

513-
// Map from values defined above the execute op to the function arguments.
514+
// Await on all dependencies before starting to execute the body region.
515+
builder.setInsertionPointToStart(resume);
516+
for (size_t i = 0; i < execute.dependencies().size(); ++i)
517+
builder.create<AwaitOp>(loc, func.getArgument(i));
518+
519+
// Map from function inputs defined above the execute op to the function
520+
// arguments.
514521
BlockAndValueMapping valueMapping;
515-
valueMapping.map(usedAbove, func.getArguments());
522+
valueMapping.map(functionInputs, func.getArguments());
516523

517524
// Clone all operations from the execute operation body into the outlined
518525
// function body, and replace all `async.yield` operations with a call
519526
// to async runtime to emplace the result token.
520-
builder.setInsertionPointToStart(resume);
521527
for (Operation &op : execute.body().getOps()) {
522528
if (isa<async::YieldOp>(op)) {
523529
builder.create<CallOp>(loc, kEmplaceToken, Type(), coro.asyncToken);
@@ -528,9 +534,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
528534

529535
// Replace the original `async.execute` with a call to outlined function.
530536
OpBuilder callBuilder(execute);
531-
SmallVector<Value, 4> usedAboveArgs(usedAbove.begin(), usedAbove.end());
532-
auto callOutlinedFunc = callBuilder.create<CallOp>(
533-
loc, func.getName(), execute.getResultTypes(), usedAboveArgs);
537+
auto callOutlinedFunc =
538+
callBuilder.create<CallOp>(loc, func.getName(), execute.getResultTypes(),
539+
functionInputs.getArrayRef());
534540
execute.replaceAllUsesWith(callOutlinedFunc.getResults());
535541
execute.erase();
536542

@@ -673,13 +679,11 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
673679
llvm::DenseMap<FuncOp, CoroMachinery> outlinedFunctions;
674680

675681
WalkResult outlineResult = module.walk([&](ExecuteOp execute) {
676-
// We currently do not support execute operations that take async
677-
// token dependencies, async value arguments or produce async results.
678-
if (!execute.dependencies().empty() || !execute.operands().empty() ||
679-
!execute.results().empty()) {
680-
execute.emitOpError(
681-
"Can't outline async.execute op with async dependencies, arguments "
682-
"or returned async results");
682+
// We currently do not support execute operations that have async value
683+
// operands or produce async results.
684+
if (!execute.operands().empty() || !execute.results().empty()) {
685+
execute.emitOpError("can't outline async.execute op with async value "
686+
"operands or returned async results");
683687
return WalkResult::interrupt();
684688
}
685689

mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
1515
}
1616

1717
// Function outlined from the async.execute operation.
18-
// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
18+
// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
1919
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
2020

2121
// Create token for return op, and mark a function as a coroutine.
@@ -79,7 +79,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
7979
}
8080

8181
// Function outlined from the inner async.execute operation.
82-
// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
82+
// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index)
8383
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
8484
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
8585
// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
@@ -89,7 +89,7 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
8989
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
9090

9191
// Function outlined from the outer async.execute operation.
92-
// CHECK: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
92+
// CHECK-LABEL: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32)
9393
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
9494
// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
9595
// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
@@ -108,4 +108,52 @@ func @nested_async_execute(%arg0: f32, %arg1: f32, %arg2: memref<1xf32>) {
108108
// CHECK: store %arg2, %arg1[%c0] : memref<1xf32>
109109
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
110110

111+
// -----
112+
113+
// CHECK-LABEL: async_execute_token_dependency
114+
func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) {
115+
// CHECK: %0 = call @async_execute_fn(%arg0, %arg1)
116+
%token = async.execute {
117+
%c0 = constant 0 : index
118+
store %arg0, %arg1[%c0] : memref<1xf32>
119+
async.yield
120+
}
121+
// CHECK: %1 = call @async_execute_fn_0(%0, %arg0, %arg1)
122+
%token_0 = async.execute [%token] {
123+
%c0 = constant 0 : index
124+
store %arg0, %arg1[%c0] : memref<1xf32>
125+
async.yield
126+
}
127+
return
128+
}
129+
130+
// Function outlined from the first async.execute operation.
131+
// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>)
132+
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
133+
// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken()
134+
// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin
135+
// CHECK: call @mlirAsyncRuntimeExecute
136+
// CHECK: llvm.call @llvm.coro.suspend
137+
// CHECK: store %arg0, %arg1[%c0] : memref<1xf32>
138+
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]])
139+
140+
// Function outlined from the second async.execute operation with dependency.
141+
// CHECK-LABEL: func @async_execute_fn_0(%arg0: !llvm.ptr<i8>, %arg1: f32, %arg2: memref<1xf32>)
142+
// CHECK-SAME: -> !llvm.ptr<i8> attributes {sym_visibility = "private"}
143+
// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken()
144+
// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin
145+
146+
// Suspend coroutine in the beginning.
147+
// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL_1]],
148+
// CHECK: llvm.call @llvm.coro.suspend
149+
150+
// Suspend coroutine second time waiting for the completion of token dependency.
151+
// CHECK: llvm.call @llvm.coro.save
152+
// CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute(%arg0, %[[HDL_1]],
153+
// CHECK: llvm.call @llvm.coro.suspend
154+
155+
// Emplace result token after second resumption.
156+
// CHECK: store %arg1, %arg2[%c0] : memref<1xf32>
157+
// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]])
158+
111159

mlir/test/mlir-cpu-runner/async.mlir

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,15 @@ func @main() {
4141
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
4242
call @print_memref_f32(%U): (memref<*xf32>) -> ()
4343

44-
%inner = async.execute {
44+
// No op async region to create a token for testing async dependency.
45+
%noop = async.execute {
4546
// CHECK: Current thread id: [[THREAD1:.*]]
47+
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
48+
async.yield
49+
}
50+
51+
%inner = async.execute [%noop] {
52+
// CHECK: Current thread id: [[THREAD2:.*]]
4653
// CHECK: [1, 2, 3, 0]
4754
store %c3, %A[%i2]: memref<4xf32>
4855
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()
@@ -52,7 +59,7 @@ func @main() {
5259
}
5360
async.await %inner : !async.token
5461

55-
// CHECK: Current thread id: [[THREAD2:.*]]
62+
// CHECK: Current thread id: [[THREAD3:.*]]
5663
// CHECK: [1, 2, 3, 4]
5764
store %c4, %A[%i3]: memref<4xf32>
5865
call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()

0 commit comments

Comments
 (0)