2
2
3
3
#include " iterators/Dialect/Iterators/IR/Iterators.h"
4
4
#include " iterators/Utils/NameAssigner.h"
5
- #include " mlir/Dialect/LLVMIR/LLVMTypes.h"
6
5
#include " mlir/IR/BuiltinAttributes.h"
7
6
#include " mlir/Transforms/DialectConversion.h"
8
7
#include " llvm/ADT/TypeSwitch.h"
9
8
10
9
using namespace mlir ;
11
10
using namespace mlir ::iterators;
12
- using namespace LLVM ;
13
11
14
12
using SymbolTriple = std::tuple<SymbolRefAttr, SymbolRefAttr, SymbolRefAttr>;
15
13
@@ -52,8 +50,8 @@ class StateTypeComputer {
52
50
// / Computes the state type of the given op whose upstream iterator ops have
53
51
// / the state types given in upstreamStateTypes.
54
52
template <typename OpType>
55
- LLVMStructType
56
- operator ()(OpType op, llvm::SmallVector<LLVMStructType > upstreamStateTypes);
53
+ StateType operator ()(OpType op,
54
+ llvm::SmallVector<StateType > upstreamStateTypes);
57
55
58
56
private:
59
57
TypeConverter typeConverter;
@@ -62,41 +60,41 @@ class StateTypeComputer {
62
60
// / The state of ConstantStreamOp consists of a single number that corresponds
63
61
// / to the index of the next struct returned by the iterator.
64
62
template <>
65
- LLVMStructType StateTypeComputer::operator ()(
66
- ConstantStreamOp op,
67
- llvm::SmallVector<LLVMStructType> /* upstreamStateTypes*/ ) {
63
+ StateType StateTypeComputer::operator ()(
64
+ ConstantStreamOp op, llvm::SmallVector<StateType> /* upstreamStateTypes*/ ) {
68
65
MLIRContext *context = op->getContext ();
69
66
Type i32 = IntegerType::get (context, /* width=*/ 32 );
70
- return LLVMStructType::getNewIdentified (
71
- context, " iterators.constant_stream_state" , {i32 });
67
+ return StateType::get (context, {i32 });
72
68
}
73
69
74
70
// / The state of FilterOp only consists of the state of its upstream iterator,
75
71
// / i.e., the state of the iterator that produces its input stream.
76
72
template <>
77
- LLVMStructType StateTypeComputer::operator ()(
78
- FilterOp op, llvm::SmallVector<LLVMStructType> upstreamStateTypes) {
79
- return LLVMStructType::getNewIdentified (
80
- op->getContext (), " iterators.filter_state" , {upstreamStateTypes[0 ]});
73
+ StateType
74
+ StateTypeComputer::operator ()(FilterOp op,
75
+ llvm::SmallVector<StateType> upstreamStateTypes) {
76
+ MLIRContext *context = op->getContext ();
77
+ return StateType::get (context, {upstreamStateTypes[0 ]});
81
78
}
82
79
83
80
// / The state of MapOp only consists of the state of its upstream iterator,
84
81
// / i.e., the state of the iterator that produces its input stream.
85
82
template <>
86
- LLVMStructType StateTypeComputer::operator ()(
87
- MapOp op, llvm::SmallVector<LLVMStructType> upstreamStateTypes) {
88
- return LLVMStructType::getNewIdentified (
89
- op->getContext (), " iterators.map_state" , {upstreamStateTypes[0 ]});
83
+ StateType
84
+ StateTypeComputer::operator ()(MapOp op,
85
+ llvm::SmallVector<StateType> upstreamStateTypes) {
86
+ MLIRContext *context = op->getContext ();
87
+ return StateType::get (context, {upstreamStateTypes[0 ]});
90
88
}
91
89
92
90
// / The state of ReduceOp only consists of the state of its upstream iterator,
93
91
// / i.e., the state of the iterator that produces its input stream.
94
92
template <>
95
- LLVMStructType StateTypeComputer::operator ()(
96
- ReduceOp op, llvm::SmallVector<LLVMStructType> upstreamStateTypes) {
97
- assert (upstreamStateTypes. size () == 1 );
98
- return LLVMStructType::getNewIdentified (
99
- op-> getContext (), " iterators.reduce_state " , {upstreamStateTypes[0 ]});
93
+ StateType
94
+ StateTypeComputer::operator ()(ReduceOp op,
95
+ llvm::SmallVector<StateType> upstreamStateTypes) {
96
+ MLIRContext *context = op-> getContext ();
97
+ return StateType::get (context , {upstreamStateTypes[0 ]});
100
98
}
101
99
102
100
// / The state of TabularViewToStreamOp consists of a single number that
@@ -106,23 +104,21 @@ LLVMStructType StateTypeComputer::operator()(
106
104
// / template <typename TabularViewType>
107
105
// / struct { int64_t currentIndex; TabularViewType view; }
108
106
template <>
109
- LLVMStructType StateTypeComputer::operator ()(
107
+ StateType StateTypeComputer::operator ()(
110
108
TabularViewToStreamOp op,
111
- llvm::SmallVector<LLVM::LLVMStructType > /* upstreamStateTypes*/ ) {
109
+ llvm::SmallVector<StateType > /* upstreamStateTypes*/ ) {
112
110
MLIRContext *context = op->getContext ();
113
- Type i64 = IntegerType::get (context, /* width=*/ 64 );
111
+ Type indexType = IntegerType::get (context, /* width=*/ 64 );
114
112
Type viewType = typeConverter.convertType (op.input ().getType ());
115
- return LLVM::LLVMStructType::getNewIdentified (
116
- op->getContext (), " iterators.tabular_view_to_stream_state" ,
117
- {i64 , viewType});
113
+ return StateType::get (context, {indexType, viewType});
118
114
}
119
115
120
116
// / Build IteratorInfo, assigning new unique names as needed. Takes the
121
- // / `LLVMStructType ` as a parameter, to ensure proper build order (all uses are
117
+ // / `StateType ` as a parameter, to ensure proper build order (all uses are
122
118
// / visited before any def).
123
119
mlir::iterators::IteratorInfo::IteratorInfo (IteratorOpInterface op,
124
120
NameAssigner &nameAssigner,
125
- LLVMStructType t) {
121
+ StateType t) {
126
122
std::tie (openFunc, nextFunc, closeFunc) =
127
123
assignFunctionNames (op, nameAssigner);
128
124
stateType = t;
@@ -166,16 +162,16 @@ mlir::iterators::IteratorAnalysis::IteratorAnalysis(
166
162
TabularViewToStreamOp
167
163
// clang-format on
168
164
>([&](auto op) {
169
- llvm::SmallVector<LLVMStructType > upstreamStateTypes;
165
+ llvm::SmallVector<StateType > upstreamStateTypes;
170
166
llvm::transform (op->getOperands (),
171
167
std::back_inserter (upstreamStateTypes),
172
168
[&](auto operand) {
173
169
Operation *def = operand.getDefiningOp ();
174
170
if (!def || !llvm::isa<IteratorOpInterface>(def))
175
- return LLVMStructType ();
171
+ return StateType ();
176
172
return getExpectedIteratorInfo (def).stateType ;
177
173
});
178
- LLVMStructType stateType = stateTypeComputer (op, upstreamStateTypes);
174
+ StateType stateType = stateTypeComputer (op, upstreamStateTypes);
179
175
setIteratorInfo (op, IteratorInfo (op, nameAssigner, stateType));
180
176
})
181
177
.Default ([&](auto op) { assert (false && " Unexpected op" ); });
0 commit comments