@@ -21,59 +21,53 @@ using namespace mlir;
21
21
22
22
namespace {
23
23
24
- template <typename Op>
24
+ template <typename Op, typename Ty >
25
25
// Pattern to convert Complex ops to ROCDL function calls.
26
26
struct ComplexOpToROCDLCall : public OpRewritePattern <Op> {
27
27
using OpRewritePattern<Op>::OpRewritePattern;
28
- ComplexOpToROCDLCall (MLIRContext *context, StringRef floatFunc,
29
- StringRef doubleFunc, PatternBenefit benefit = 1 )
30
- : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
31
- doubleFunc (doubleFunc) {}
28
+ ComplexOpToROCDLCall (MLIRContext *context, StringRef funcName,
29
+ PatternBenefit benefit = 1 )
30
+ : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
32
31
33
32
LogicalResult matchAndRewrite (Op op, PatternRewriter &rewriter) const final {
34
33
Operation *symTable = SymbolTable::getNearestSymbolTable (op);
35
34
Type resType = op.getType ();
36
35
if (auto complexType = dyn_cast<ComplexType>(resType))
37
36
resType = complexType.getElementType ();
38
- FloatType floatTy = dyn_cast<FloatType>(resType);
39
- if (!floatTy)
40
- return failure ();
41
-
42
- StringRef name;
43
- if (floatTy.isF64 ())
44
- name = doubleFunc;
45
- else if (floatTy.isF32 ())
46
- name = floatFunc;
47
- else
37
+ if (!isa<Ty>(resType))
48
38
return failure ();
49
39
50
40
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
51
- SymbolTable::lookupSymbolIn (symTable, name ));
41
+ SymbolTable::lookupSymbolIn (symTable, funcName ));
52
42
if (!opFunc) {
53
43
OpBuilder::InsertionGuard guard (rewriter);
54
44
rewriter.setInsertionPointToStart (&symTable->getRegion (0 ).front ());
55
45
auto funcTy = FunctionType::get (
56
46
rewriter.getContext (), op->getOperandTypes (), op->getResultTypes ());
57
- opFunc =
58
- rewriter. create <func::FuncOp>(rewriter. getUnknownLoc (), name, funcTy);
47
+ opFunc = rewriter. create <func::FuncOp>(rewriter. getUnknownLoc (), funcName,
48
+ funcTy);
59
49
opFunc.setPrivate ();
60
50
}
61
- rewriter.replaceOpWithNewOp <func::CallOp>(op, name , op.getType (),
51
+ rewriter.replaceOpWithNewOp <func::CallOp>(op, funcName , op.getType (),
62
52
op->getOperands ());
63
53
return success ();
64
54
}
65
55
66
56
private:
67
- std::string floatFunc, doubleFunc ;
57
+ std::string funcName ;
68
58
};
69
59
} // namespace
70
60
71
61
void mlir::populateComplexToROCDLConversionPatterns (
72
62
RewritePatternSet &patterns) {
73
- patterns.add <ComplexOpToROCDLCall<complex::AbsOp>>(
74
- patterns.getContext (), " __ocml_cabs_f32" , " __ocml_cabs_f64" );
75
- patterns.add <ComplexOpToROCDLCall<complex::ExpOp>>(
76
- patterns.getContext (), " __ocml_cexp_f32" , " __ocml_cexp_f64" );
63
+ patterns.add <ComplexOpToROCDLCall<complex::AbsOp, Float32Type>>(
64
+ patterns.getContext (), " __ocml_cabs_f32" );
65
+ patterns.add <ComplexOpToROCDLCall<complex::AbsOp, Float64Type>>(
66
+ patterns.getContext (), " __ocml_cabs_f64" );
67
+ patterns.add <ComplexOpToROCDLCall<complex::ExpOp, Float32Type>>(
68
+ patterns.getContext (), " __ocml_cexp_f32" );
69
+ patterns.add <ComplexOpToROCDLCall<complex::ExpOp, Float64Type>>(
70
+ patterns.getContext (), " __ocml_cexp_f64" );
77
71
}
78
72
79
73
namespace {
0 commit comments