Skip to content

Commit 31a2196

Browse files
Iterators: Implement lowering of iterator state type and ops. (#607)
1 parent edbb0fb commit 31a2196

File tree

3 files changed

+146
-4
lines changed

3 files changed

+146
-4
lines changed

experimental/iterators/lib/Conversion/StatesToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ add_mlir_conversion_library(MLIRStatesToLLVM
1111
MLIRIterators
1212
MLIRLLVMDialect
1313
MLIRPass
14+
MLIRSCFTransforms
1415
)

experimental/iterators/lib/Conversion/StatesToLLVM/StatesToLLVM.cpp

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Func/IR/FuncOps.h"
1515
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1616
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17+
#include "mlir/Dialect/SCF/Transforms.h"
1718
#include "mlir/IR/ImplicitLocOpBuilder.h"
1819
#include "mlir/Transforms/DialectConversion.h"
1920

@@ -38,17 +39,86 @@ class StateTypeConverter : public TypeConverter {
3839
public:
3940
StateTypeConverter() {
4041
addConversion([](Type type) { return type; });
42+
addConversion([&](Type type) { return convertIteratorStateType(type); });
4143
}
4244

4345
private:
46+
/// Maps an iterator state type to a corresponding LLVMStructType.
47+
Optional<Type> convertIteratorStateType(Type type) {
48+
if (auto stateType = type.dyn_cast<StateType>()) {
49+
llvm::SmallVector<Type> fieldTypes(stateType.getFieldTypes().begin(),
50+
stateType.getFieldTypes().end());
51+
for (auto &type : fieldTypes) {
52+
type = convertType(type);
53+
}
54+
return LLVMStructType::getLiteral(type.getContext(), fieldTypes);
55+
}
56+
return llvm::None;
57+
}
58+
};
59+
60+
struct UndefStateOpLowering : public OpConversionPattern<UndefStateOp> {
61+
UndefStateOpLowering(TypeConverter &typeConverter, MLIRContext *context,
62+
PatternBenefit benefit = 1)
63+
: OpConversionPattern(typeConverter, context, benefit) {}
64+
65+
LogicalResult
66+
matchAndRewrite(UndefStateOp op, OpAdaptor adaptor,
67+
ConversionPatternRewriter &rewriter) const override {
68+
Location loc = op->getLoc();
69+
Type structType = getTypeConverter()->convertType(op.getResult().getType());
70+
Value undef = rewriter.create<LLVM::UndefOp>(loc, structType);
71+
rewriter.replaceOp(op, undef);
72+
return success();
73+
}
74+
};
75+
76+
struct ExtractValueOpLowering
77+
: public OpConversionPattern<iterators::ExtractValueOp> {
78+
ExtractValueOpLowering(TypeConverter &typeConverter, MLIRContext *context,
79+
PatternBenefit benefit = 1)
80+
: OpConversionPattern(typeConverter, context, benefit) {}
81+
82+
LogicalResult
83+
matchAndRewrite(iterators::ExtractValueOp op, OpAdaptor adaptor,
84+
ConversionPatternRewriter &rewriter) const override {
85+
Location loc = op->getLoc();
86+
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
87+
Value value =
88+
createExtractValueOp(rewriter, loc, resultType, adaptor.state(),
89+
{adaptor.index().getSExtValue()});
90+
rewriter.replaceOp(op, value);
91+
return success();
92+
}
93+
};
94+
95+
struct InsertValueOpLowering
96+
: public OpConversionPattern<iterators::InsertValueOp> {
97+
InsertValueOpLowering(TypeConverter &typeConverter, MLIRContext *context,
98+
PatternBenefit benefit = 1)
99+
: OpConversionPattern(typeConverter, context, benefit) {}
100+
101+
LogicalResult
102+
matchAndRewrite(iterators::InsertValueOp op, OpAdaptor adaptor,
103+
ConversionPatternRewriter &rewriter) const override {
104+
Location loc = op->getLoc();
105+
Value updatedState =
106+
createInsertValueOp(rewriter, loc, adaptor.state(), adaptor.value(),
107+
{adaptor.index().getSExtValue()});
108+
rewriter.replaceOp(op, updatedState);
109+
return success();
110+
}
44111
};
45112

46113
void mlir::iterators::populateStatesToLLVMConversionPatterns(
47114
RewritePatternSet &patterns, TypeConverter &typeConverter) {
48-
// patterns.add<
49-
// // clang-format off
50-
// // clang-format on
51-
// >(typeConverter, patterns.getContext());
115+
patterns.add<
116+
// clang-format off
117+
UndefStateOpLowering,
118+
ExtractValueOpLowering,
119+
InsertValueOpLowering
120+
// clang-format on
121+
>(typeConverter, patterns.getContext());
52122
}
53123

54124
void ConvertStatesToLLVMPass::runOnOperation() {
@@ -80,6 +150,10 @@ void ConvertStatesToLLVMPass::runOnOperation() {
80150
return typeConverter.isSignatureLegal(op.getCalleeType());
81151
});
82152

153+
// Add patterns that convert the types in SCF constructs.
154+
scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
155+
target);
156+
83157
// Use UnrealizedConversionCast as materializations, which have to be cleaned
84158
// up by later passes.
85159
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: mlir-proto-opt %s -convert-states-to-llvm \
2+
// RUN: | FileCheck --enable-var-scope %s
3+
4+
func.func @testUndefInsertExtract() {
5+
// CHECK-LABEL: func.func @testUndefInsertExtract() {
6+
%initial_state = iterators.undefstate : !iterators.state<i32>
7+
// CHECK-NEXT: %[[V0:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
8+
%value = arith.constant 0 : i32
9+
// CHECK-NEXT: %[[V1:.*]] = arith.constant 0 : i32
10+
%inserted_state = iterators.insertvalue %value into %initial_state[0] : !iterators.state<i32>
11+
// CHECK-NEXT: %[[V2:.*]] = llvm.insertvalue %[[V1]], %[[V0]][0 : index] : !llvm.struct<(i32)>
12+
%extracted_value = iterators.extractvalue %inserted_state[0] : !iterators.state<i32>
13+
// CHECK-NEXT: %[[V3:.*]] = llvm.extractvalue %[[V2]][0 : index] : !llvm.struct<(i32)>
14+
return
15+
// CHECK-NEXT: return
16+
}
17+
// CHECK-NEXT: }
18+
19+
func.func @testNestedType() {
20+
// CHECK-LABEL: func.func @testNestedType() {
21+
%outer_state = iterators.undefstate : !iterators.state<i32, !iterators.state<i32>>
22+
// CHECK-NEXT: %[[V0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, struct<(i32)>)>
23+
%inner_state = iterators.extractvalue %outer_state[1] : !iterators.state<i32, !iterators.state<i32>>
24+
// CHECK-NEXT: %[[V1:.*]] = llvm.extractvalue %[[V0]][1 : index] : !llvm.struct<(i32, struct<(i32)>)>
25+
return
26+
// CHECK-NEXT: return
27+
}
28+
// CHECK-NEXT: }
29+
30+
func.func @testFuncReturn(%state: !iterators.state<i32>) -> !iterators.state<i32> {
31+
// CHECK-LABEL: func.func @testFuncReturn(%{{.*}}: !llvm.struct<(i32)>) -> !llvm.struct<(i32)> {
32+
return %state : !iterators.state<i32>
33+
// CHECK-NEXT: return %[[V0:.*]] : !llvm.struct<(i32)>
34+
}
35+
// CHECK-NEXT: }
36+
37+
func.func @testCall() {
38+
// CHECK-LABEL: func.func @testCall() {
39+
%state = iterators.undefstate : !iterators.state<i32>
40+
// CHECK-NEXT: %[[V0:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
41+
func.call @testFuncReturn(%state) : (!iterators.state<i32>) -> !iterators.state<i32>
42+
// CHECK-NEXT: %[[V1:.*]] = call @testFuncReturn(%[[V0]]) : (!llvm.struct<(i32)>) -> !llvm.struct<(i32)>
43+
return
44+
// CHECK-NEXT: return
45+
}
46+
// CHECK-NEXT: }
47+
48+
func.func @testScf() {
49+
// CHECK-LABEL: func.func @testScf() {
50+
%state = iterators.undefstate : !iterators.state<i32>
51+
// CHECK-NEXT: %[[V0:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
52+
%cmp = arith.constant true
53+
// CHECK-NEXT: %[[V1:.*]] = arith.constant true
54+
%a = scf.if %cmp -> !iterators.state<i32> {
55+
// CHECK-NEXT: %[[V2:.*]] = scf.if %[[V1]] -> (!llvm.struct<(i32)>) {
56+
scf.yield %state : !iterators.state<i32>
57+
// CHECK-NEXT: scf.yield %[[V0]] : !llvm.struct<(i32)>
58+
} else {
59+
// CHECK-NEXT: } else {
60+
scf.yield %state : !iterators.state<i32>
61+
// CHECK-NEXT: scf.yield %[[V0]] : !llvm.struct<(i32)>
62+
}
63+
// CHECK-NEXT: }
64+
return
65+
// CHECK-NEXT: return
66+
}
67+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)