11
11
#include " mlir/Dialect/Func/IR/FuncOps.h"
12
12
#include " mlir/IR/PatternMatch.h"
13
13
#include " mlir/Transforms/DialectConversion.h"
14
- #include < optional>
15
14
16
15
namespace mlir {
17
16
#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
@@ -21,36 +20,38 @@ namespace mlir {
21
20
using namespace mlir ;
22
21
23
22
namespace {
24
- struct FloatTypeResolver {
25
- std::optional<bool > operator ()(Type type) const {
26
- auto elementType = cast<FloatType>(type);
27
- if (!isa<Float32Type, Float64Type>(elementType))
28
- return {};
29
- return elementType.getIntOrFloatBitWidth () == 64 ;
30
- }
31
- };
32
23
33
- template <typename Op, typename TypeResolver = FloatTypeResolver>
34
- struct ScalarOpToROCDLCall : public OpRewritePattern <Op> {
24
+ template <typename Op>
25
+ // Pattern to convert Complex ops to ROCDL function calls.
26
+ struct ComplexOpToROCDLCall : public OpRewritePattern <Op> {
35
27
using OpRewritePattern<Op>::OpRewritePattern;
36
- ScalarOpToROCDLCall (MLIRContext *context, StringRef floatFunc,
37
- StringRef doubleFunc, PatternBenefit benefit)
28
+ ComplexOpToROCDLCall (MLIRContext *context, StringRef floatFunc,
29
+ StringRef doubleFunc, PatternBenefit benefit = 1 )
38
30
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
39
31
doubleFunc (doubleFunc) {}
40
32
41
33
LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final {
42
- auto module = SymbolTable::getNearestSymbolTable (op);
43
- auto isDouble = TypeResolver ()(op.getType ());
44
- if (!isDouble.has_value ())
34
+ Operation *symTable = SymbolTable::getNearestSymbolTable (op);
35
+ Type resType = op.getType ();
36
+ if (auto complexType = dyn_cast<ComplexType>(resType))
37
+ resType = complexType.getElementType ();
38
+ FloatType floatTy = dyn_cast<FloatType>(resType);
39
+ if (!floatTy)
45
40
return failure ();
46
41
47
- auto name = *isDouble ? doubleFunc : floatFunc;
42
+ StringRef name;
43
+ if (floatTy.isF64 ())
44
+ name = doubleFunc;
45
+ else if (floatTy.isF32 ())
46
+ name = floatFunc;
47
+ else
48
+ return failure ();
48
49
49
50
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
50
- SymbolTable::lookupSymbolIn (module , name));
51
+ SymbolTable::lookupSymbolIn (symTable , name));
51
52
if (!opFunc) {
52
53
OpBuilder::InsertionGuard guard (rewriter);
53
- rewriter.setInsertionPointToStart (&module ->getRegion (0 ).front ());
54
+ rewriter.setInsertionPointToStart (&symTable ->getRegion (0 ).front ());
54
55
auto funcTy = FunctionType::get (
55
56
rewriter.getContext (), op->getOperandTypes (), op->getResultTypes ());
56
57
opFunc =
@@ -67,10 +68,12 @@ struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
67
68
};
68
69
} // namespace
69
70
70
- void mlir::populateComplexToROCDLConversionPatterns (RewritePatternSet &patterns,
71
- PatternBenefit benefit) {
72
- patterns.add <ScalarOpToROCDLCall<complex::AbsOp>>(
73
- patterns.getContext (), " __ocml_cabs_f32" , " __ocml_cabs_f64" , benefit);
71
+ void mlir::populateComplexToROCDLConversionPatterns (
72
+ RewritePatternSet &patterns) {
73
+ patterns.add <ComplexOpToROCDLCall<complex::AbsOp>>(
74
+ patterns.getContext (), " __ocml_cabs_f32" , " __ocml_cabs_f64" );
75
+ patterns.add <ComplexOpToROCDLCall<complex::ExpOp>>(
76
+ patterns.getContext (), " __ocml_cexp_f32" , " __ocml_cexp_f64" );
74
77
}
75
78
76
79
namespace {
@@ -81,14 +84,14 @@ struct ConvertComplexToROCDLPass
81
84
} // namespace
82
85
83
86
void ConvertComplexToROCDLPass::runOnOperation () {
84
- auto module = getOperation ();
87
+ Operation *op = getOperation ();
85
88
86
89
RewritePatternSet patterns (&getContext ());
87
- populateComplexToROCDLConversionPatterns (patterns, /* benefit= */ 1 );
90
+ populateComplexToROCDLConversionPatterns (patterns);
88
91
89
92
ConversionTarget target (getContext ());
90
93
target.addLegalDialect <func::FuncDialect>();
91
- target.addIllegalOp <complex::AbsOp>();
92
- if (failed (applyPartialConversion (module , target, std::move (patterns))))
94
+ target.addIllegalOp <complex::AbsOp, complex::ExpOp >();
95
+ if (failed (applyPartialConversion (op , target, std::move (patterns))))
93
96
signalPassFailure ();
94
97
}
0 commit comments