Skip to content

Commit 655af65

Browse files
committed
[MLIR] Add async.value type to Async dialect
Return values from async regions as !async.value<...>. Reviewed By: mehdi_amini, csigg Differential Revision: https://reviews.llvm.org/D88510
1 parent c3193e4 commit 655af65

File tree

5 files changed

+198
-16
lines changed

5 files changed

+198
-16
lines changed

mlir/include/mlir/Dialect/Async/IR/Async.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,28 @@
2222
namespace mlir {
2323
namespace async {
2424

25+
namespace detail {
26+
struct ValueTypeStorage;
27+
} // namespace detail
28+
2529
/// The token type to represent asynchronous operation completion.
2630
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
2731
public:
2832
using Base::Base;
2933
};
3034

35+
/// The value type to represent values returned from asynchronous operations.
36+
class ValueType
37+
: public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
38+
public:
39+
using Base::Base;
40+
41+
/// Get or create an async ValueType with the provided value type.
42+
static ValueType get(Type valueType);
43+
44+
Type getValueType();
45+
};
46+
3147
} // namespace async
3248
} // namespace mlir
3349

mlir/include/mlir/Dialect/Async/IR/AsyncBase.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,24 @@ def Async_TokenType : DialectType<AsyncDialect,
3939
}];
4040
}
4141

42+
class Async_ValueType<Type type>
43+
: DialectType<AsyncDialect,
44+
And<[
45+
CPred<"$_self.isa<::mlir::async::ValueType>()">,
46+
SubstLeaves<"$_self",
47+
"$_self.cast<::mlir::async::ValueType>().getValueType()",
48+
type.predicate>
49+
]>, "async value type with " # type.description # " underlying type"> {
50+
let typeDescription = [{
51+
`async.value` represents a value returned by asynchronous operations,
52+
which may or may not be available currently, but will be available at some
53+
point in the future.
54+
}];
55+
56+
Type valueType = type;
57+
}
58+
59+
def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
60+
"async value type">;
61+
4262
#endif // ASYNC_BASE_TD

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,24 @@ def Async_ExecuteOp : Async_Op<"execute"> {
4040
state). All dependencies must be made explicit with async execute arguments
4141
(`async.token` or `async.value`).
4242

43-
Example:
44-
4543
```mlir
46-
%0 = async.execute {
47-
"compute0"(...)
48-
async.yield
49-
} : !async.token
44+
%done, %values = async.execute {
45+
%0 = "compute0"(...) : !some.type
46+
async.yield %1 : f32
47+
} : !async.token, !async.value<!some.type>
5048

51-
%1 = "compute1"(...)
49+
%1 = "compute1"(...) : !some.type
5250
```
5351
}];
5452

5553
// TODO: Take async.tokens/async.values as arguments.
5654
let arguments = (ins );
57-
let results = (outs Async_TokenType:$done);
55+
let results = (outs Async_TokenType:$done,
56+
Variadic<Async_AnyValueType>:$values);
5857
let regions = (region SizedRegion<1>:$body);
5958

60-
let assemblyFormat = "$body attr-dict `:` type($done)";
59+
let printer = [{ return ::mlir::async::print(p, *this); }];
60+
let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
6161
}
6262

6363
def Async_YieldOp :
@@ -71,6 +71,8 @@ def Async_YieldOp :
7171
let arguments = (ins Variadic<AnyType>:$operands);
7272

7373
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
74+
75+
let verifier = [{ return ::mlir::async::verify(*this); }];
7476
}
7577

7678
#endif // ASYNC_OPS

mlir/lib/Dialect/Async/IR/Async.cpp

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@
1919
#include "llvm/ADT/TypeSwitch.h"
2020
#include "llvm/Support/raw_ostream.h"
2121

22-
using namespace mlir;
23-
using namespace mlir::async;
22+
namespace mlir {
23+
namespace async {
2424

2525
void AsyncDialect::initialize() {
2626
addOperations<
2727
#define GET_OP_LIST
2828
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
2929
>();
3030
addTypes<TokenType>();
31+
addTypes<ValueType>();
3132
}
3233

3334
/// Parse a type registered to this dialect.
@@ -39,16 +40,129 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
3940
if (keyword == "token")
4041
return TokenType::get(getContext());
4142

43+
if (keyword == "value") {
44+
Type ty;
45+
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
46+
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
47+
return Type();
48+
}
49+
return ValueType::get(ty);
50+
}
51+
4252
parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
4353
return Type();
4454
}
4555

4656
/// Print a type registered to this dialect.
4757
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
4858
TypeSwitch<Type>(type)
49-
.Case<TokenType>([&](Type) { os << "token"; })
59+
.Case<TokenType>([&](TokenType) { os << "token"; })
60+
.Case<ValueType>([&](ValueType valueTy) {
61+
os << "value<";
62+
os.printType(valueTy.getValueType());
63+
os << '>';
64+
})
5065
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
5166
}
5267

68+
//===----------------------------------------------------------------------===//
69+
/// ValueType
70+
//===----------------------------------------------------------------------===//
71+
72+
namespace detail {
73+
74+
// Storage for `async.value<T>` type, the only member is the wrapped type.
75+
struct ValueTypeStorage : public TypeStorage {
76+
ValueTypeStorage(Type valueType) : valueType(valueType) {}
77+
78+
/// The hash key used for uniquing.
79+
using KeyTy = Type;
80+
bool operator==(const KeyTy &key) const { return key == valueType; }
81+
82+
/// Construction.
83+
static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
84+
Type valueType) {
85+
return new (allocator.allocate<ValueTypeStorage>())
86+
ValueTypeStorage(valueType);
87+
}
88+
89+
Type valueType;
90+
};
91+
92+
} // namespace detail
93+
94+
ValueType ValueType::get(Type valueType) {
95+
return Base::get(valueType.getContext(), valueType);
96+
}
97+
98+
Type ValueType::getValueType() { return getImpl()->valueType; }
99+
100+
//===----------------------------------------------------------------------===//
101+
// YieldOp
102+
//===----------------------------------------------------------------------===//
103+
104+
static LogicalResult verify(YieldOp op) {
105+
// Get the underlying value types from async values returned from the
106+
// parent `async.execute` operation.
107+
auto executeOp = op.getParentOfType<ExecuteOp>();
108+
auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
109+
return result.getType().cast<ValueType>().getValueType();
110+
});
111+
112+
if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin()))
113+
return op.emitOpError("Operand types do not match the types returned from "
114+
"the parent ExecuteOp");
115+
116+
return success();
117+
}
118+
119+
//===----------------------------------------------------------------------===//
120+
/// ExecuteOp
121+
//===----------------------------------------------------------------------===//
122+
123+
static void print(OpAsmPrinter &p, ExecuteOp op) {
124+
p << "async.execute ";
125+
p.printRegion(op.body());
126+
p.printOptionalAttrDict(op.getAttrs());
127+
p << " : ";
128+
p.printType(op.done().getType());
129+
if (!op.values().empty())
130+
p << ", ";
131+
llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
132+
p.printType(result.getType());
133+
});
134+
}
135+
136+
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
137+
MLIRContext *ctx = result.getContext();
138+
139+
// Parse asynchronous region.
140+
Region *body = result.addRegion();
141+
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
142+
/*enableNameShadowing=*/false))
143+
return failure();
144+
145+
// Parse operation attributes.
146+
NamedAttrList attrs;
147+
if (parser.parseOptionalAttrDict(attrs))
148+
return failure();
149+
result.addAttributes(attrs);
150+
151+
// Parse result types.
152+
SmallVector<Type, 4> resultTypes;
153+
if (parser.parseColonTypeList(resultTypes))
154+
return failure();
155+
156+
// First result type must be an async token type.
157+
if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
158+
return failure();
159+
parser.addTypesToList(resultTypes, result.types);
160+
161+
return success();
162+
}
163+
164+
} // namespace async
165+
} // namespace mlir
166+
53167
#define GET_OP_CLASSES
54168
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"

mlir/test/Dialect/Async/ops.mlir

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,46 @@
11
// RUN: mlir-opt %s | FileCheck %s
22

3-
// CHECK-LABEL: @identity
4-
func @identity(%arg0 : !async.token) -> !async.token {
3+
// CHECK-LABEL: @identity_token
4+
func @identity_token(%arg0 : !async.token) -> !async.token {
55
// CHECK: return %arg0 : !async.token
66
return %arg0 : !async.token
77
}
88

9+
// CHECK-LABEL: @identity_value
10+
func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
11+
// CHECK: return %arg0 : !async.value<f32>
12+
return %arg0 : !async.value<f32>
13+
}
14+
915
// CHECK-LABEL: @empty_async_execute
1016
func @empty_async_execute() -> !async.token {
11-
%0 = async.execute {
17+
%done = async.execute {
1218
async.yield
1319
} : !async.token
1420

15-
return %0 : !async.token
21+
// CHECK: return %done : !async.token
22+
return %done : !async.token
23+
}
24+
25+
// CHECK-LABEL: @return_async_value
26+
func @return_async_value() -> !async.value<f32> {
27+
%done, %values = async.execute {
28+
%cst = constant 1.000000e+00 : f32
29+
async.yield %cst : f32
30+
} : !async.token, !async.value<f32>
31+
32+
// CHECK: return %values : !async.value<f32>
33+
return %values : !async.value<f32>
34+
}
35+
36+
// CHECK-LABEL: @return_async_values
37+
func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
38+
%done, %values:2 = async.execute {
39+
%cst1 = constant 1.000000e+00 : f32
40+
%cst2 = constant 2.000000e+00 : f32
41+
async.yield %cst1, %cst2 : f32, f32
42+
} : !async.token, !async.value<f32>, !async.value<f32>
43+
44+
// CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
45+
return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
1646
}

0 commit comments

Comments
 (0)