14
14
#include " mlir/Dialect/Func/IR/FuncOps.h"
15
15
#include " mlir/Dialect/Func/Transforms/FuncConversions.h"
16
16
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
17
+ #include " mlir/Dialect/SCF/Transforms.h"
17
18
#include " mlir/IR/ImplicitLocOpBuilder.h"
18
19
#include " mlir/Transforms/DialectConversion.h"
19
20
@@ -38,17 +39,86 @@ class StateTypeConverter : public TypeConverter {
38
39
public:
39
40
StateTypeConverter () {
40
41
addConversion ([](Type type) { return type; });
42
+ addConversion ([&](Type type) { return convertIteratorStateType (type); });
41
43
}
42
44
43
45
private:
46
+ // / Maps an iterator state type to a corresponding LLVMStructType.
47
+ Optional<Type> convertIteratorStateType (Type type) {
48
+ if (auto stateType = type.dyn_cast <StateType>()) {
49
+ llvm::SmallVector<Type> fieldTypes (stateType.getFieldTypes ().begin (),
50
+ stateType.getFieldTypes ().end ());
51
+ for (auto &type : fieldTypes) {
52
+ type = convertType (type);
53
+ }
54
+ return LLVMStructType::getLiteral (type.getContext (), fieldTypes);
55
+ }
56
+ return llvm::None;
57
+ }
58
+ };
59
+
60
+ struct UndefStateOpLowering : public OpConversionPattern <UndefStateOp> {
61
+ UndefStateOpLowering (TypeConverter &typeConverter, MLIRContext *context,
62
+ PatternBenefit benefit = 1 )
63
+ : OpConversionPattern(typeConverter, context, benefit) {}
64
+
65
+ LogicalResult
66
+ matchAndRewrite (UndefStateOp op, OpAdaptor adaptor,
67
+ ConversionPatternRewriter &rewriter) const override {
68
+ Location loc = op->getLoc ();
69
+ Type structType = getTypeConverter ()->convertType (op.getResult ().getType ());
70
+ Value undef = rewriter.create <LLVM::UndefOp>(loc, structType);
71
+ rewriter.replaceOp (op, undef);
72
+ return success ();
73
+ }
74
+ };
75
+
76
+ struct ExtractValueOpLowering
77
+ : public OpConversionPattern<iterators::ExtractValueOp> {
78
+ ExtractValueOpLowering (TypeConverter &typeConverter, MLIRContext *context,
79
+ PatternBenefit benefit = 1 )
80
+ : OpConversionPattern(typeConverter, context, benefit) {}
81
+
82
+ LogicalResult
83
+ matchAndRewrite (iterators::ExtractValueOp op, OpAdaptor adaptor,
84
+ ConversionPatternRewriter &rewriter) const override {
85
+ Location loc = op->getLoc ();
86
+ Type resultType = getTypeConverter ()->convertType (op.getResult ().getType ());
87
+ Value value =
88
+ createExtractValueOp (rewriter, loc, resultType, adaptor.state (),
89
+ {adaptor.index ().getSExtValue ()});
90
+ rewriter.replaceOp (op, value);
91
+ return success ();
92
+ }
93
+ };
94
+
95
+ struct InsertValueOpLowering
96
+ : public OpConversionPattern<iterators::InsertValueOp> {
97
+ InsertValueOpLowering (TypeConverter &typeConverter, MLIRContext *context,
98
+ PatternBenefit benefit = 1 )
99
+ : OpConversionPattern(typeConverter, context, benefit) {}
100
+
101
+ LogicalResult
102
+ matchAndRewrite (iterators::InsertValueOp op, OpAdaptor adaptor,
103
+ ConversionPatternRewriter &rewriter) const override {
104
+ Location loc = op->getLoc ();
105
+ Value updatedState =
106
+ createInsertValueOp (rewriter, loc, adaptor.state (), adaptor.value (),
107
+ {adaptor.index ().getSExtValue ()});
108
+ rewriter.replaceOp (op, updatedState);
109
+ return success ();
110
+ }
44
111
};
45
112
46
113
void mlir::iterators::populateStatesToLLVMConversionPatterns (
47
114
RewritePatternSet &patterns, TypeConverter &typeConverter) {
48
- // patterns.add<
49
- // // clang-format off
50
- // // clang-format on
51
- // >(typeConverter, patterns.getContext());
115
+ patterns.add <
116
+ // clang-format off
117
+ UndefStateOpLowering,
118
+ ExtractValueOpLowering,
119
+ InsertValueOpLowering
120
+ // clang-format on
121
+ >(typeConverter, patterns.getContext ());
52
122
}
53
123
54
124
void ConvertStatesToLLVMPass::runOnOperation () {
@@ -80,6 +150,10 @@ void ConvertStatesToLLVMPass::runOnOperation() {
80
150
return typeConverter.isSignatureLegal (op.getCalleeType ());
81
151
});
82
152
153
+ // Add patterns that convert the types in SCF constructs.
154
+ scf::populateSCFStructuralTypeConversionsAndLegality (typeConverter, patterns,
155
+ target);
156
+
83
157
// Use UnrealizedConversionCast as materializations, which have to be cleaned
84
158
// up by later passes.
85
159
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
0 commit comments