Skip to content

Commit 9bf20dc

Browse files
Iterators: Add boilerplate code for new 'States to LLVM' pass. (#605)
1 parent b18bced commit 9bf20dc

File tree

9 files changed

+175
-2
lines changed

9 files changed

+175
-2
lines changed

experimental/iterators/include/iterators/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define ITERATORS_CONVERSION_PASSES_H
1111

1212
#include "iterators/Conversion/IteratorsToLLVM/IteratorsToLLVM.h"
13+
#include "iterators/Conversion/StatesToLLVM/StatesToLLVM.h"
1314
#include "iterators/Conversion/TabularToLLVM/TabularToLLVM.h"
1415
#include "mlir/Pass/Pass.h"
1516

experimental/iterators/include/iterators/Conversion/Passes.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def ConvertIteratorsToLLVM : Pass<"convert-iterators-to-llvm", "ModuleOp"> {
2929
currently only works if the use-def chains of `Stream`s form a tree, i.e.,
3030
every `Stream` is used as an operand by exactly one subsequent iterator.
3131

32-
More precisely, for each iterator, the lowering produces a state in the form
33-
of an `!llvm.struct<...>`, including any local state that the iterator might
32+
More precisely, for each iterator, the lowering produces a state with a
33+
number of typed fields, including any local state that the iterator might
3434
require **plus the states of all iterators in the transitive use-def chain**
3535
of its operands. The computations are expressed as three functions, `Open`,
3636
`Next`, and `Close`, which operate on that state and which continuously pass
@@ -58,6 +58,26 @@ def ConvertIteratorsToLLVM : Pass<"convert-iterators-to-llvm", "ModuleOp"> {
5858
];
5959
}
6060

61+
//===----------------------------------------------------------------------===//
62+
// IteratorsToLLVM
63+
//===----------------------------------------------------------------------===//
64+
65+
def ConvertStatesToLLVM : Pass<"convert-states-to-llvm", "ModuleOp"> {
66+
let summary = "Convert the operations on iterator states into the LLVM "
67+
"dialect";
68+
let description = [{
69+
This lowering pass converts operations on iterator states into equivalent
70+
operations of the LLVM dialect. Currently, the ops on iterator states are
71+
essentially equivalent to the LLVM ops dealing with structs (but allow
72+
arbitrary types), so the lowering only consists of straightforward,
73+
one-to-one patterns.
74+
}];
75+
let constructor = "mlir::createConvertStatesToLLVMPass()";
76+
let dependentDialects = [
77+
"LLVM::LLVMDialect"
78+
];
79+
}
80+
6181
//===----------------------------------------------------------------------===//
6282
// TabularToLLVM
6383
//===----------------------------------------------------------------------===//
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===-- StatesToLLVM.h - Utils to convert from Iterators states -*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef ITERATORS_CONVERSION_STATESTOLLVM_STATESTOLLVM_H
10+
#define ITERATORS_CONVERSION_STATESTOLLVM_STATESTOLLVM_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class ModuleOp;
16+
template <typename T>
17+
class OperationPass;
18+
class RewritePatternSet;
19+
class TypeConverter;
20+
21+
namespace iterators {
22+
23+
/// Populate the given list with patterns that convert from Iterator states to
24+
/// LLVM.
25+
void populateStatesToLLVMConversionPatterns(RewritePatternSet &patterns,
26+
TypeConverter &typeConverter);
27+
28+
} // namespace iterators
29+
30+
/// Create a pass to convert operations on Iterator states to the LLVM dialect.
31+
std::unique_ptr<OperationPass<ModuleOp>> createConvertStatesToLLVMPass();
32+
33+
} // namespace mlir
34+
35+
#endif // ITERATORS_CONVERSION_STATESTOLLVM_STATESTOLLVM_H

experimental/iterators/lib/CAPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ add_mlir_public_c_api_library(IteratorsCAPI
77
MLIRTabular
88
MLIRTabularToLLVM
99
MLIRPass
10+
MLIRStatesToLLVM
1011
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IteratorsToLLVM)
2+
add_subdirectory(StatesToLLVM)
23
add_subdirectory(TabularToLLVM)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
add_mlir_conversion_library(MLIRStatesToLLVM
2+
StatesToLLVM.cpp
3+
4+
DEPENDS
5+
MLIRIteratorsConversionIncGen
6+
7+
LINK_LIBS PUBLIC
8+
IteratorsUtils
9+
MLIRFuncDialect
10+
MLIRFuncTransforms
11+
MLIRIterators
12+
MLIRLLVMDialect
13+
MLIRPass
14+
)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//===-- StatesToLLVM.h - Conversion from Iterator states to LLVM-*- C++ -*-===//
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "iterators/Conversion/StatesToLLVM/StatesToLLVM.h"
10+
11+
#include "../PassDetail.h"
12+
#include "iterators/Dialect/Iterators/IR/Iterators.h"
13+
#include "iterators/Utils/MLIRSupport.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
16+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17+
#include "mlir/IR/ImplicitLocOpBuilder.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
namespace mlir {
21+
class MLIRContext;
22+
} // namespace mlir
23+
24+
using namespace mlir;
25+
using namespace mlir::func;
26+
using namespace mlir::iterators;
27+
using namespace mlir::LLVM;
28+
29+
namespace {
30+
struct ConvertStatesToLLVMPass
31+
: public ConvertStatesToLLVMBase<ConvertStatesToLLVMPass> {
32+
void runOnOperation() override;
33+
};
34+
} // namespace
35+
36+
/// Maps state types from the Iterators dialect to corresponding types in LLVM.
37+
class StateTypeConverter : public TypeConverter {
38+
public:
39+
StateTypeConverter() {
40+
addConversion([](Type type) { return type; });
41+
}
42+
43+
private:
44+
};
45+
46+
void mlir::iterators::populateStatesToLLVMConversionPatterns(
47+
RewritePatternSet &patterns, TypeConverter &typeConverter) {
48+
// patterns.add<
49+
// // clang-format off
50+
// // clang-format on
51+
// >(typeConverter, patterns.getContext());
52+
}
53+
54+
void ConvertStatesToLLVMPass::runOnOperation() {
55+
auto module = getOperation();
56+
StateTypeConverter typeConverter;
57+
58+
// Convert the remaining ops of this dialect using dialect conversion.
59+
ConversionTarget target(getContext());
60+
target.addLegalDialect<LLVMDialect>();
61+
target.addLegalOp<ModuleOp>();
62+
RewritePatternSet patterns(&getContext());
63+
64+
populateStatesToLLVMConversionPatterns(patterns, typeConverter);
65+
66+
// Add patterns that converts function signature and calls.
67+
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
68+
typeConverter);
69+
populateCallOpTypeConversionPattern(patterns, typeConverter);
70+
populateReturnOpTypeConversionPattern(patterns, typeConverter);
71+
72+
// Force application of that pattern if signature is not legal yet.
73+
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
74+
return typeConverter.isSignatureLegal(op.getFunctionType());
75+
});
76+
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
77+
return typeConverter.isLegal(op.getOperandTypes());
78+
});
79+
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
80+
return typeConverter.isSignatureLegal(op.getCalleeType());
81+
});
82+
83+
// Use UnrealizedConversionCast as materializations, which have to be cleaned
84+
// up by later passes.
85+
auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
86+
Location loc) {
87+
auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
88+
return Optional<Value>(cast.getResult(0));
89+
};
90+
typeConverter.addSourceMaterialization(addUnrealizedCast);
91+
typeConverter.addTargetMaterialization(addUnrealizedCast);
92+
93+
if (failed(applyPartialConversion(module, target, std::move(patterns))))
94+
signalPassFailure();
95+
}
96+
97+
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertStatesToLLVMPass() {
98+
return std::make_unique<ConvertStatesToLLVMPass>();
99+
}

tools/mlir-proto-lsp-server/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if(SANDBOX_ENABLE_ITERATORS)
1616
list(APPEND dialect_libs MLIRTabular)
1717
list(APPEND conversion_libs MLIRIteratorsToLLVM)
1818
list(APPEND conversion_libs MLIRTabularToLLVM)
19+
list(APPEND conversion_libs MLIRStatesToLLVM)
1920
endif()
2021

2122
target_link_libraries(mlir-proto-lsp-server

tools/mlir-proto-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if(SANDBOX_ENABLE_ITERATORS)
1616
list(APPEND dialect_libs MLIRTabular)
1717
list(APPEND conversion_libs MLIRIteratorsToLLVM)
1818
list(APPEND conversion_libs MLIRTabularToLLVM)
19+
list(APPEND conversion_libs MLIRStatesToLLVM)
1920
endif()
2021

2122
target_link_libraries(mlir-proto-opt

0 commit comments

Comments
 (0)