19
19
#include " llvm/ADT/TypeSwitch.h"
20
20
#include " llvm/Support/raw_ostream.h"
21
21
22
- using namespace mlir ;
23
- using namespace mlir :: async;
22
+ namespace mlir {
23
+ namespace async {
24
24
25
25
void AsyncDialect::initialize () {
26
26
addOperations<
27
27
#define GET_OP_LIST
28
28
#include " mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
29
29
>();
30
30
addTypes<TokenType>();
31
+ addTypes<ValueType>();
31
32
}
32
33
33
34
// / Parse a type registered to this dialect.
@@ -39,16 +40,129 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
39
40
if (keyword == " token" )
40
41
return TokenType::get (getContext ());
41
42
43
+ if (keyword == " value" ) {
44
+ Type ty;
45
+ if (parser.parseLess () || parser.parseType (ty) || parser.parseGreater ()) {
46
+ parser.emitError (parser.getNameLoc (), " failed to parse async value type" );
47
+ return Type ();
48
+ }
49
+ return ValueType::get (ty);
50
+ }
51
+
42
52
parser.emitError (parser.getNameLoc (), " unknown async type: " ) << keyword;
43
53
return Type ();
44
54
}
45
55
46
56
// / Print a type registered to this dialect.
47
57
void AsyncDialect::printType (Type type, DialectAsmPrinter &os) const {
48
58
TypeSwitch<Type>(type)
49
- .Case <TokenType>([&](Type) { os << " token" ; })
59
+ .Case <TokenType>([&](TokenType) { os << " token" ; })
60
+ .Case <ValueType>([&](ValueType valueTy) {
61
+ os << " value<" ;
62
+ os.printType (valueTy.getValueType ());
63
+ os << ' >' ;
64
+ })
50
65
.Default ([](Type) { llvm_unreachable (" unexpected 'async' type kind" ); });
51
66
}
52
67
68
+ // ===----------------------------------------------------------------------===//
69
+ // / ValueType
70
+ // ===----------------------------------------------------------------------===//
71
+
72
+ namespace detail {
73
+
74
+ // Storage for `async.value<T>` type, the only member is the wrapped type.
75
+ struct ValueTypeStorage : public TypeStorage {
76
+ ValueTypeStorage (Type valueType) : valueType(valueType) {}
77
+
78
+ // / The hash key used for uniquing.
79
+ using KeyTy = Type;
80
+ bool operator ==(const KeyTy &key) const { return key == valueType; }
81
+
82
+ // / Construction.
83
+ static ValueTypeStorage *construct (TypeStorageAllocator &allocator,
84
+ Type valueType) {
85
+ return new (allocator.allocate <ValueTypeStorage>())
86
+ ValueTypeStorage (valueType);
87
+ }
88
+
89
+ Type valueType;
90
+ };
91
+
92
+ } // namespace detail
93
+
94
+ ValueType ValueType::get (Type valueType) {
95
+ return Base::get (valueType.getContext (), valueType);
96
+ }
97
+
98
+ Type ValueType::getValueType () { return getImpl ()->valueType ; }
99
+
100
+ // ===----------------------------------------------------------------------===//
101
+ // YieldOp
102
+ // ===----------------------------------------------------------------------===//
103
+
104
+ static LogicalResult verify (YieldOp op) {
105
+ // Get the underlying value types from async values returned from the
106
+ // parent `async.execute` operation.
107
+ auto executeOp = op.getParentOfType <ExecuteOp>();
108
+ auto types = llvm::map_range (executeOp.values (), [](const OpResult &result) {
109
+ return result.getType ().cast <ValueType>().getValueType ();
110
+ });
111
+
112
+ if (!std::equal (types.begin (), types.end (), op.getOperandTypes ().begin ()))
113
+ return op.emitOpError (" Operand types do not match the types returned from "
114
+ " the parent ExecuteOp" );
115
+
116
+ return success ();
117
+ }
118
+
119
+ // ===----------------------------------------------------------------------===//
120
+ // / ExecuteOp
121
+ // ===----------------------------------------------------------------------===//
122
+
123
+ static void print (OpAsmPrinter &p, ExecuteOp op) {
124
+ p << " async.execute " ;
125
+ p.printRegion (op.body ());
126
+ p.printOptionalAttrDict (op.getAttrs ());
127
+ p << " : " ;
128
+ p.printType (op.done ().getType ());
129
+ if (!op.values ().empty ())
130
+ p << " , " ;
131
+ llvm::interleaveComma (op.values (), p, [&](const OpResult &result) {
132
+ p.printType (result.getType ());
133
+ });
134
+ }
135
+
136
+ static ParseResult parseExecuteOp (OpAsmParser &parser, OperationState &result) {
137
+ MLIRContext *ctx = result.getContext ();
138
+
139
+ // Parse asynchronous region.
140
+ Region *body = result.addRegion ();
141
+ if (parser.parseRegion (*body, /* arguments=*/ {}, /* argTypes=*/ {},
142
+ /* enableNameShadowing=*/ false ))
143
+ return failure ();
144
+
145
+ // Parse operation attributes.
146
+ NamedAttrList attrs;
147
+ if (parser.parseOptionalAttrDict (attrs))
148
+ return failure ();
149
+ result.addAttributes (attrs);
150
+
151
+ // Parse result types.
152
+ SmallVector<Type, 4 > resultTypes;
153
+ if (parser.parseColonTypeList (resultTypes))
154
+ return failure ();
155
+
156
+ // First result type must be an async token type.
157
+ if (resultTypes.empty () || resultTypes.front () != TokenType::get (ctx))
158
+ return failure ();
159
+ parser.addTypesToList (resultTypes, result.types );
160
+
161
+ return success ();
162
+ }
163
+
164
+ } // namespace async
165
+ } // namespace mlir
166
+
53
167
#define GET_OP_CLASSES
54
168
#include " mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
0 commit comments