Skip to content

Commit ba1230d

Browse files
committed
[MLIR] Add ComplexTOROCDL pass
This patch adds a new ComplexToROCDL pass to convert complex.abs operations to __ocml_cabs_f32/__ocml_cabs_f64 calls.
1 parent 35f6d91 commit ba1230d

File tree

9 files changed

+164
-1
lines changed

9 files changed

+164
-1
lines changed

flang/lib/Optimizer/CodeGen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_flang_library(FIRCodeGen
4040
MLIRMathToLLVM
4141
MLIRMathToLibm
4242
MLIRMathToROCDL
43+
MLIRComplexToROCDL
4344
MLIROpenMPToLLVM
4445
MLIROpenACCDialect
4546
MLIRBuiltinToLLVMIRTranslation

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
3434
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
3535
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
36+
#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
3637
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
3738
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
3839
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
@@ -4105,8 +4106,10 @@ class FIRToLLVMLowering
41054106
// GPU library calls, the rest can be converted to LLVM intrinsics, which
41064107
// is handled in the mathToLLVM conversion. The lowering to libm calls is
41074108
// not needed since all math operations are handled this way.
4108-
if (isAMDGCN)
4109+
if (isAMDGCN) {
41094110
mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
4111+
mathConvertionPM.addPass(mlir::createConvertComplexToROCDL());
4112+
}
41104113

41114114
// Convert math::FPowI operations to inline implementation
41124115
// only if the exponent's width is greater than 32, otherwise,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
2+
#define MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_
3+
4+
#include "mlir/IR/PatternMatch.h"
5+
#include "mlir/Pass/Pass.h"
6+
7+
namespace mlir {
8+
class RewritePatternSet;
9+
10+
#define GEN_PASS_DECL_CONVERTCOMPLEXTOROCDL
11+
#include "mlir/Conversion/Passes.h.inc"
12+
13+
/// Populate the given list with patterns that convert from Complex to ROCDL
14+
/// calls.
15+
void populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
16+
PatternBenefit benefit);
17+
} // namespace mlir
18+
19+
#endif // MLIR_CONVERSION_COMPLEXTOROCDL_COMPLEXTOROCDL_H_

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
2424
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
2525
#include "mlir/Conversion/ComplexToLibm/ComplexToLibm.h"
26+
#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
2627
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRVPass.h"
2728
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
2829
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,18 @@ def ConvertComplexToLibm : Pass<"convert-complex-to-libm", "ModuleOp"> {
312312
let dependentDialects = ["func::FuncDialect"];
313313
}
314314

315+
//===----------------------------------------------------------------------===//
316+
// ComplexToROCDL
317+
//===----------------------------------------------------------------------===//
318+
319+
def ConvertComplexToROCDL : Pass<"convert-complex-to-rocdl", "ModuleOp"> {
320+
let summary = "Convert Complex dialect to ROCDL calls";
321+
let description = [{
322+
This pass converts supported Complex ops to calls to the AMD device library.
323+
}];
324+
let dependentDialects = ["func::FuncDialect"];
325+
}
326+
315327
//===----------------------------------------------------------------------===//
316328
// ComplexToSPIRV
317329
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_subdirectory(AsyncToLLVM)
1313
add_subdirectory(BufferizationToMemRef)
1414
add_subdirectory(ComplexCommon)
1515
add_subdirectory(ComplexToLibm)
16+
add_subdirectory(ComplexToROCDL)
1617
add_subdirectory(ComplexToLLVM)
1718
add_subdirectory(ComplexToSPIRV)
1819
add_subdirectory(ComplexToStandard)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_conversion_library(MLIRComplexToROCDL
2+
ComplexToROCDL.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ComplexToROCDL
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRComplexDialect
15+
MLIRFuncDialect
16+
MLIRPass
17+
MLIRTransformUtils
18+
)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===-- ComplexToROCDL.cpp - conversion from Complex to ROCDL calls -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ComplexToROCDL/ComplexToROCDL.h"
10+
11+
#include "mlir/Dialect/Complex/IR/Complex.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Transforms/DialectConversion.h"
15+
#include <optional>
16+
17+
namespace mlir {
18+
#define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDL
19+
#include "mlir/Conversion/Passes.h.inc"
20+
} // namespace mlir
21+
22+
using namespace mlir;
23+
24+
namespace {
25+
struct FloatTypeResolver {
26+
std::optional<bool> operator()(Type type) const {
27+
auto elementType = cast<FloatType>(type);
28+
if (!isa<Float32Type, Float64Type>(elementType))
29+
return {};
30+
return elementType.getIntOrFloatBitWidth() == 64;
31+
}
32+
};
33+
34+
template <typename Op, typename TypeResolver = FloatTypeResolver>
35+
struct ScalarOpToROCDLCall : public OpRewritePattern<Op> {
36+
using OpRewritePattern<Op>::OpRewritePattern;
37+
ScalarOpToROCDLCall(MLIRContext *context, StringRef floatFunc,
38+
StringRef doubleFunc, PatternBenefit benefit)
39+
: OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
40+
doubleFunc(doubleFunc) {}
41+
42+
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
43+
auto module = SymbolTable::getNearestSymbolTable(op);
44+
auto isDouble = TypeResolver()(op.getType());
45+
if (!isDouble.has_value())
46+
return failure();
47+
48+
auto name = *isDouble ? doubleFunc : floatFunc;
49+
50+
auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
51+
SymbolTable::lookupSymbolIn(module, name));
52+
if (!opFunc) {
53+
OpBuilder::InsertionGuard guard(rewriter);
54+
rewriter.setInsertionPointToStart(&module->getRegion(0).front());
55+
auto funcTy = FunctionType::get(
56+
rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
57+
opFunc =
58+
rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name, funcTy);
59+
opFunc.setPrivate();
60+
}
61+
rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
62+
op->getOperands());
63+
return success();
64+
}
65+
66+
private:
67+
std::string floatFunc, doubleFunc;
68+
};
69+
} // namespace
70+
71+
void mlir::populateComplexToROCDLConversionPatterns(RewritePatternSet &patterns,
72+
PatternBenefit benefit) {
73+
patterns.add<ScalarOpToROCDLCall<complex::AbsOp>>(
74+
patterns.getContext(), "__ocml_cabs_f32", "__ocml_cabs_f64", benefit);
75+
}
76+
77+
namespace {
78+
struct ConvertComplexToROCDLPass
79+
: public impl::ConvertComplexToROCDLBase<ConvertComplexToROCDLPass> {
80+
void runOnOperation() override;
81+
};
82+
} // namespace
83+
84+
void ConvertComplexToROCDLPass::runOnOperation() {
85+
auto module = getOperation();
86+
87+
RewritePatternSet patterns(&getContext());
88+
populateComplexToROCDLConversionPatterns(patterns, /*benefit=*/1);
89+
90+
ConversionTarget target(getContext());
91+
target.addLegalDialect<func::FuncDialect>();
92+
target.addIllegalOp<complex::AbsOp>();
93+
if (failed(applyPartialConversion(module, target, std::move(patterns))))
94+
signalPassFailure();
95+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt %s -convert-complex-to-rocdl -canonicalize | FileCheck %s
2+
3+
// CHECK-DAG: @__ocml_cabs_f32(complex<f32>) -> f32
4+
// CHECK-DAG: @__ocml_cabs_f64(complex<f64>) -> f64
5+
6+
func.func @abs_caller(%f: complex<f32>, %d: complex<f64>) -> (f32, f64) {
7+
// CHECK: %[[RF:.*]] = call @__ocml_cabs_f32(%[[F:.*]])
8+
%rf = complex.abs %f : complex<f32>
9+
// CHECK: %[[RD:.*]] = call @__ocml_cabs_f64(%[[D:.*]])
10+
%rd = complex.abs %d : complex<f64>
11+
// CHECK: return %[[RF]], %[[RD]]
12+
return %rf, %rd : f32, f64
13+
}

0 commit comments

Comments
 (0)