Skip to content

Commit 1c1a06d

Browse files
committed
Address reviewer changes.
Add conversion for complex.exp.
1 parent 4061a93 commit 1c1a06d

File tree

5 files changed

+54
-42
lines changed

5 files changed

+54
-42
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4146,24 +4146,24 @@ class FIRToLLVMLowering
41464146
// conversions that affect the ModuleOp, e.g. create new
41474147
// function operations in it. We have to run such conversions
41484148
// as passes here.
4149-
mlir::OpPassManager mathConvertionPM("builtin.module");
4149+
mlir::OpPassManager mathConversionPM("builtin.module");
41504150

41514151
bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
41524152
// If compiling for AMD target some math operations must be lowered to AMD
41534153
// GPU library calls, the rest can be converted to LLVM intrinsics, which
41544154
// is handled in the mathToLLVM conversion. The lowering to libm calls is
41554155
// not needed since all math operations are handled this way.
41564156
if (isAMDGCN) {
4157-
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
4158-
mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
4157+
mathConversionPM.addPass(mlir::createConvertMathToROCDL());
4158+
mathConversionPM.addPass(mlir::createConvertComplexToROCDL());
41594159
}
41604160

41614161
// Convert math::FPowI operations to inline implementation
41624162
// only if the exponent's width is greater than 32, otherwise,
41634163
// it will be lowered to LLVM intrinsic operation by a later conversion.
41644164
mlir::ConvertMathToFuncsOptions mathToFuncsOptions{};
41654165
mathToFuncsOptions.minWidthOfFPowIExponent = 33;
4166-
mathConvertionPM.addPass(
4166+
mathConversionPM.addPass(
41674167
mlir::createConvertMathToFuncs(mathToFuncsOptions));
41684168

41694169
mlir::ConvertComplexToStandardPassOptions complexToStandardOptions{};
@@ -4176,15 +4176,15 @@ class FIRToLLVMLowering
41764176
complexToStandardOptions.complexRange =
41774177
mlir::complex::ComplexRangeFlags::improved;
41784178
}
4179-
mathConvertionPM.addPass(
4179+
mathConversionPM.addPass(
41804180
mlir::createConvertComplexToStandardPass(complexToStandardOptions));
41814181

41824182
// Convert Math dialect operations into LLVM dialect operations.
41834183
// There is no way to prefer MathToLLVM patterns over MathToLibm
41844184
// patterns (applied below), so we have to run MathToLLVM conversion here.
4185-
mathConvertionPM.addNestedPass<mlir::func::FuncOp>(
4185+
mathConversionPM.addNestedPass<mlir::func::FuncOp>(
41864186
mlir::createConvertMathToLLVMPass());
4187-
if (mlir::failed(runPipeline(mathConvertionPM, mod)))
4187+
if (mlir::failed(runPipeline(mathConversionPM, mod)))
41884188
return signalPassFailure();
41894189

41904190
std::optional<mlir::DataLayout> dl =

mlir/include/mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ class RewritePatternSet;
2020

2121
/// Populate the given list with patterns that convert from Complex to ROCDL
2222
/// calls.
23-
void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
24-
PatternBenefit benefit);
23+
void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns);
2524
} // namespace mlir
2625

2726
#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_

mlir/lib/Conversion/ComplexToROCDL/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ add_mlir_conversion_library(MLIRComplexToROCDL
77
DEPENDS
88
MLIRConversionPassIncGen
99

10-
LINK_COMPONENTS
11-
Core
12-
1310
LINK_LIBS PUBLIC
1411
MLIRComplexDialect
1512
MLIRFuncDialect

mlir/lib/Conversion/ComplexToROCDL/ComplexToROCDL.cpp

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "mlir/Dialect/Func/IR/FuncOps.h"
1212
#include "mlir/IR/PatternMatch.h"
1313
#include "mlir/Transforms/DialectConversion.h"
14-
#include <optional>
1514

1615
namespace mlir {
1716
#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
@@ -21,36 +20,38 @@ namespace mlir {
2120
using namespace mlir;
2221

2322
namespace {
24-
struct FloatTypeResolver {
25-
std::optional<bool> operator()(Type type) const {
26-
auto elementType = cast<FloatType>(type);
27-
if (!isa<Float32Type, Float64Type>(elementType))
28-
return {};
29-
return elementType.getIntOrFloatBitWidth() == 64;
30-
}
31-
};
3223

33-
template <typename Op, typename TypeResolver = FloatTypeResolver>
34-
struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
24+
template <typename Op>
25+
// Pattern to convert Complex ops to ROCDL function calls.
26+
struct ComplexOpToROCDLCall : public OpRewritePattern<Op> {
3527
using OpRewritePattern<Op>::OpRewritePattern;
36-
ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
37-
StringRef doubleFunc, PatternBenefit benefit)
28+
ComplexOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
29+
StringRef doubleFunc, PatternBenefit benefit = 1)
3830
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
3931
doubleFunc(doubleFunc) {}
4032

4133
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
42-
auto module = SymbolTable::getNearestSymbolTable(op);
43-
auto isDouble = TypeResolver()(op.getType());
44-
if (!isDouble.has_value())
34+
Operation *symTable = SymbolTable::getNearestSymbolTable(op);
35+
Type resType = op.getType();
36+
if (auto complexType = dyn_cast<ComplexType>(resType))
37+
resType = complexType.getElementType();
38+
FloatType floatTy = dyn_cast<FloatType>(resType);
39+
if (!floatTy)
4540
return failure();
4641

47-
auto name = *isDouble ? doubleFunc : floatFunc;
42+
StringRef name;
43+
if (floatTy.isF64())
44+
name = doubleFunc;
45+
else if (floatTy.isF32())
46+
name = floatFunc;
47+
else
48+
return failure();
4849

4950
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
50-
SymbolTable::lookupSymbolIn(module, name));
51+
SymbolTable::lookupSymbolIn(symTable, name));
5152
if (!opFunc) {
5253
OpBuilder::InsertionGuard guard(rewriter);
53-
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
54+
rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
5455
auto funcTy = FunctionType::get(
5556
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
5657
opFunc =
@@ -67,10 +68,12 @@ struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
6768
};
6869
} // namespace
6970

70-
void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
71-
PatternBenefit benefit) {
72-
patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
73-
patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
71+
void mlir::populateComplexToROCDLConversionPatterns(
72+
RewritePatternSet &patterns) {
73+
patterns.add<ComplexOpToROCDLCall<complex::AbsOp>>(
74+
patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64");
75+
patterns.add<ComplexOpToROCDLCall<complex::ExpOp>>(
76+
patterns.getContext(), "__ocml_cexp_f32", "__ocml_cexp_f64");
7477
}
7578

7679
namespace {
@@ -81,14 +84,14 @@ struct ConvertComplexToROCDLPass
8184
} // namespace
8285

8386
void ConvertComplexToROCDLPass::runOnOperation() {
84-
auto module = getOperation();
87+
Operation *op = getOperation();
8588

8689
RewritePatternSet patterns(&getContext());
87-
populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
90+
populateComplexToROCDLConversionPatterns(patterns);
8891

8992
ConversionTarget target(getContext());
9093
target.addLegalDialect<func::FuncDialect>();
91-
target.addIllegalOp<complex::AbsOp>();
92-
if (failed(applyPartialConversion(module, target, std::move(patterns))))
94+
target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
95+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
9396
signalPassFailure();
9497
}
Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1-
// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
1+
// RUN: mlir-opt %s -convert-complex-to-rocdl | FileCheck %s
22

33
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
44
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
5+
// CHECK-DAG: @__ocml_cexp_f32(complex<f32>) -> complex<f32>
6+
// CHECK-DAG: @__ocml_cexp_f64(complex<f64>) -> complex<f64>
57

8+
//CHECK-LABEL: @abs_caller
69
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
7-
// CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
10+
// CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%{{.*}})
811
%rf = complex.abs %f : complex<f32>
9-
// CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
12+
// CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%{{.*}})
1013
%rd = complex.abs %d : complex<f64>
1114
// CHECK: return %[[RF]], %[[RD]]
1215
return %rf, %rd : f32, f64
1316
}
17+
18+
//CHECK-LABEL: @exp_caller
19+
func.func @exp_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
20+
// CHECK: %[[EF:.*]] = call @__ocml_cexp_f32(%{{.*}})
21+
%ef = complex.exp %f : complex<f32>
22+
// CHECK: %[[ED:.*]] = call @__ocml_cexp_f64(%{{.*}})
23+
%ed = complex.exp %d : complex<f64>
24+
// CHECK: return %[[EF]], %[[ED]]
25+
return %ef, %ed : complex<f32>, complex<f64>
26+
}

0 commit comments

Comments
 (0)