Skip to content

Commit 0949d97

Browse files
committed
Add FloatTy as a template parameter.
1 parent 012aa6c commit 0949d97

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,59 +21,53 @@ using namespace mlir;
2121

2222
namespace {
2323

24-
template <typename Op>
24+
template <typename Op, typename Ty>
2525
// Pattern to convert Complex ops to ROCDL function calls.
2626
struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
2727
using OpRewritePattern<Op>::OpRewritePattern;
28-
ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
29-
StringRef doubleFunc, PatternBenefit benefit = 1)
30-
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
31-
doubleFunc(doubleFunc) {}
28+
ComplexOpToROCDLCall(MLIRContext *context, StringRef funcName,
29+
PatternBenefit benefit = 1)
30+
: OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
3231

3332
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
3433
Operation *symTable = SymbolTable::getNearestSymbolTable(op);
3534
Type resType = op.getType();
3635
if (auto complexType = dyn_cast<ComplexType>(resType))
3736
resType = complexType.getElementType();
38-
FloatType floatTy = dyn_cast<FloatType>(resType);
39-
if (!floatTy)
40-
return failure();
41-
42-
StringRef name;
43-
if (floatTy.isF64())
44-
name = doubleFunc;
45-
else if (floatTy.isF32())
46-
name = floatFunc;
47-
else
37+
if (!isa<Ty>(resType))
4838
return failure();
4939

5040
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
51-
SymbolTable::lookupSymbolIn(symTable, name));
41+
SymbolTable::lookupSymbolIn(symTable, funcName));
5242
if (!opFunc) {
5343
OpBuilder::InsertionGuard guard(rewriter);
5444
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
5545
auto funcTy = FunctionType::get(
5646
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
57-
opFunc =
58-
rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
47+
opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), funcName,
48+
funcTy);
5949
opFunc.setPrivate();
6050
}
61-
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
51+
rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
6252
op->getOperands());
6353
return success();
6454
}
6555

6656
private:
67-
std::string floatFunc, doubleFunc;
57+
std::string funcName;
6858
};
6959
} // namespace
7060

7161
void mlir::populateComplexToROCDLConversionPatterns(
7262
RewritePatternSet &patterns) {
73-
patterns.add<ComplexOpToROCDLCall<complex::AbsOp>>(
74-
patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64");
75-
patterns.add<ComplexOpToROCDLCall<complex::ExpOp>>(
76-
patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64");
63+
patterns.add<ComplexOpToROCDLCall<complex::AbsOp, Float32Type>>(
64+
patterns.getContext(), "__ocml_cabs_f32");
65+
patterns.add<ComplexOpToROCDLCall<complex::AbsOp, Float64Type>>(
66+
patterns.getContext(), "__ocml_cabs_f64");
67+
patterns.add<ComplexOpToROCDLCall<complex::ExpOp, Float32Type>>(
68+
patterns.getContext(), "__ocml_cexp_f32");
69+
patterns.add<ComplexOpToROCDLCall<complex::ExpOp, Float64Type>>(
70+
patterns.getContext(), "__ocml_cexp_f64");
7771
}
7872

7973
namespace {

0 commit comments

Comments
 (0)