From dd6b0fb47552b0ba46d9ebc80043bf2361afbea0 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 19 Jun 2025 16:48:25 +0800 Subject: [PATCH 01/13] rewrite complex division (__divdc3) into arithmetic-base computation --- flang/include/flang/Optimizer/HLFIR/Passes.td | 4 + .../Optimizer/HLFIR/Transforms/CMakeLists.txt | 1 + .../HLFIR/Transforms/XSComplexConversion.cpp | 98 +++++++++++++++++++ flang/lib/Optimizer/Passes/Pipelines.cpp | 1 + 4 files changed, 104 insertions(+) create mode 100644 flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index 04d7aec5fe489..a66d40c11924c 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -34,6 +34,10 @@ def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> { let summary = "Lower HLFIR transformational intrinsic operations"; } +def XSComplexConversion : Pass<"XS-complex", "::mlir::ModuleOp"> { + let summary = "XSComplexConversion transformational intrinsic operations"; +} + def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::mlir::ModuleOp"> { let summary = "Lower HLFIR ordered assignments like forall and where operations"; let options = [ diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index cc74273d9c5d9..69bafc076f2a2 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ add_flang_library(HLFIRTransforms SimplifyHLFIRIntrinsics.cpp OptimizedBufferization.cpp PropagateFortranVariableAttributes.cpp + XSComplexConversion.cpp DEPENDS CUFAttrsIncGen diff --git a/flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp new file mode 100644 index 0000000000000..f82818c1fbb10 --- /dev/null +++ b/flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp @@ -0,0 +1,98 @@ +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Builder/HLFIRTools.h" +#include "flang/Optimizer/Builder/Todo.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/Support/FIRContext.h" +#include "flang/Optimizer/HLFIR/HLFIRDialect.h" +#include "flang/Optimizer/HLFIR/HLFIROps.h" +#include "flang/Optimizer/HLFIR/Passes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include +#include "llvm/Support/Debug.h" +#define DEBUG_TYPE "xs-complex-conversion" + +namespace hlfir { +#define GEN_PASS_DEF_XSCOMPLEXCONVERSION +#include "flang/Optimizer/HLFIR/Passes.h.inc" +} + +static llvm::cl::opt EnableXSDivc( + "enable-XSDivc", llvm::cl::init(false), llvm::cl::Hidden, + llvm::cl::desc("Enable calling of XSDivc.")); + +namespace { +class HlfirXSComplexConversion : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { + if (!EnableXSDivc) { + LLVM_DEBUG(llvm::dbgs() << "XS Complex Division support is currently disabled \n"); + return mlir::failure(); + } + fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; + const mlir::Location &loc = callOp.getLoc(); + if (!callOp.getCallee()) { + LLVM_DEBUG(llvm::dbgs() << "No callee found for CallOp at " << loc << "\n"); + return mlir::failure(); + } + + const mlir::SymbolRefAttr &callee = *callOp.getCallee(); + const auto &fctName = callee.getRootReference().getValue(); + if (fctName!= "__divdc3") + return mlir::failure(); + + const mlir::Type &eleTy = callOp.getOperands()[0].getType(); + const mlir::Type &resTy = callOp.getResult(0).getType(); + + auto x0 = callOp.getOperands()[0]; + auto y0 = callOp.getOperands()[1]; + auto x1 = callOp.getOperands()[2]; + auto y1 = callOp.getOperands()[3]; + + auto xx = rewriter.create(loc, eleTy, x0, x1); + auto x1x1 = rewriter.create(loc, eleTy, x1, x1); + auto yx = rewriter.create(loc, eleTy, y0, x1); + auto xy = rewriter.create(loc, eleTy, x0, y1); + auto yy = rewriter.create(loc, eleTy, y0, y1); + auto y1y1 = rewriter.create(loc, eleTy, y1, y1); + auto d = rewriter.create(loc, eleTy, x1x1, y1y1); + auto rrn = rewriter.create(loc, eleTy, xx, yy); + auto rin = rewriter.create(loc, eleTy, yx, xy); + auto rr = rewriter.create(loc, eleTy, rrn, d); + auto ri = rewriter.create(loc, eleTy, rin, d); + auto ra = rewriter.create(loc, resTy); + auto indexAttr0 = builder.getArrayAttr({builder.getI32IntegerAttr(0)}); + auto indexAttr1 = builder.getArrayAttr({builder.getI32IntegerAttr(1)}); + auto r1 = rewriter.create(loc,resTy, ra, rr, indexAttr0); + auto r0 = rewriter.create(loc,resTy, r1, ri, indexAttr1); + rewriter.replaceOp(callOp, r0.getResult()); + return mlir::success(); + } +}; + +class XSComplexConversion : public hlfir::impl::XSComplexConversionBase { +public: + void runOnOperation() override { + mlir::ModuleOp module = this->getOperation(); + mlir::MLIRContext *context = &getContext(); + mlir::RewritePatternSet patterns(context); + patterns.insert(context); + + mlir::GreedyRewriteConfig config; + config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); + + if (mlir::failed(mlir::applyPatternsGreedily(module, std::move(patterns), config))) + { + mlir::emitError(mlir::UnknownLoc::get(context), "failure in XS Complex HLFIR intrinsic lowering"); + signalPassFailure(); + } + } +}; +} \ No newline at end of file diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 70f57bdeddd3f..7c48282bf0f67 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -267,6 +267,7 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, } } pm.addPass(hlfir::createLowerHLFIROrderedAssignments()); + pm.addPass(hlfir::createXSComplexConversion()); pm.addPass(hlfir::createLowerHLFIRIntrinsics()); hlfir::BufferizeHLFIROptions bufferizeOptions; From 62130a287eea3a544d4a28f65ca5a60f081a568f Mon Sep 17 00:00:00 2001 From: Hanwen Date: Fri, 20 Jun 2025 17:24:36 +0800 Subject: [PATCH 02/13] revise code naming conventions --- flang/include/flang/Optimizer/HLFIR/Passes.td | 4 +- .../Optimizer/HLFIR/Transforms/CMakeLists.txt | 2 +- ...on.cpp => XSComplexDivisionConversion.cpp} | 61 +++++++++++-------- flang/lib/Optimizer/Passes/Pipelines.cpp | 2 +- 4 files changed, 39 insertions(+), 30 deletions(-) rename flang/lib/Optimizer/HLFIR/Transforms/{XSComplexConversion.cpp => XSComplexDivisionConversion.cpp} (50%) diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index a66d40c11924c..9e67743d2c410 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -34,8 +34,8 @@ def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> { let summary = "Lower HLFIR transformational intrinsic operations"; } -def XSComplexConversion : Pass<"XS-complex", "::mlir::ModuleOp"> { - let summary = "XSComplexConversion transformational intrinsic operations"; +def XSComplexDivisionConversion : Pass<"XS-complex", "::mlir::ModuleOp"> { + let summary = "XSComplexDivisionConversion transformational intrinsic operations"; } def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::mlir::ModuleOp"> { diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index 69bafc076f2a2..12ed3cefabfba 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -12,7 +12,7 @@ add_flang_library(HLFIRTransforms SimplifyHLFIRIntrinsics.cpp OptimizedBufferization.cpp PropagateFortranVariableAttributes.cpp - XSComplexConversion.cpp + XSComplexDivisionConversion.cpp DEPENDS CUFAttrsIncGen diff --git a/flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/XSComplexDivisionConversion.cpp similarity index 50% rename from flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp rename to flang/lib/Optimizer/HLFIR/Transforms/XSComplexDivisionConversion.cpp index f82818c1fbb10..2a63d12d6bd3c 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/XSComplexConversion.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/XSComplexDivisionConversion.cpp @@ -19,7 +19,7 @@ #define DEBUG_TYPE "xs-complex-conversion" namespace hlfir { -#define GEN_PASS_DEF_XSCOMPLEXCONVERSION +#define GEN_PASS_DEF_XSCOMPLEXDIVISIONCONVERSION #include "flang/Optimizer/HLFIR/Passes.h.inc" } @@ -28,7 +28,7 @@ static llvm::cl::opt EnableXSDivc( llvm::cl::desc("Enable calling of XSDivc.")); namespace { -class HlfirXSComplexConversion : public mlir::OpRewritePattern { +class HlfirXSComplexDivisionConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { @@ -51,46 +51,55 @@ class HlfirXSComplexConversion : public mlir::OpRewritePattern { const mlir::Type &eleTy = callOp.getOperands()[0].getType(); const mlir::Type &resTy = callOp.getResult(0).getType(); - auto x0 = callOp.getOperands()[0]; - auto y0 = callOp.getOperands()[1]; - auto x1 = callOp.getOperands()[2]; - auto y1 = callOp.getOperands()[3]; + auto x0 = callOp.getOperands()[0]; // real part of numerator : x0 + auto y0 = callOp.getOperands()[1]; // imaginary part of numerator : y0 + auto x1 = callOp.getOperands()[2]; // real part of denominator : x1 + auto y1 = callOp.getOperands()[3]; // imaginary part of denominator : y1 - auto xx = rewriter.create(loc, eleTy, x0, x1); - auto x1x1 = rewriter.create(loc, eleTy, x1, x1); - auto yx = rewriter.create(loc, eleTy, y0, x1); - auto xy = rewriter.create(loc, eleTy, x0, y1); - auto yy = rewriter.create(loc, eleTy, y0, y1); - auto y1y1 = rewriter.create(loc, eleTy, y1, y1); - auto d = rewriter.create(loc, eleTy, x1x1, y1y1); - auto rrn = rewriter.create(loc, eleTy, xx, yy); - auto rin = rewriter.create(loc, eleTy, yx, xy); - auto rr = rewriter.create(loc, eleTy, rrn, d); - auto ri = rewriter.create(loc, eleTy, rin, d); - auto ra = rewriter.create(loc, resTy); - auto indexAttr0 = builder.getArrayAttr({builder.getI32IntegerAttr(0)}); - auto indexAttr1 = builder.getArrayAttr({builder.getI32IntegerAttr(1)}); - auto r1 = rewriter.create(loc,resTy, ra, rr, indexAttr0); - auto r0 = rewriter.create(loc,resTy, r1, ri, indexAttr1); - rewriter.replaceOp(callOp, r0.getResult()); + auto x0x1 = rewriter.create(loc, eleTy, x0, x1); // x0 * x1 + auto x1Squared = rewriter.create(loc, eleTy, x1, x1); // x1^2 + auto y0x1 = rewriter.create(loc, eleTy, y0, x1); // y0 * x1 + auto x0y1 = rewriter.create(loc, eleTy, x0, y1); // x0 * y1 + auto y0y1 = rewriter.create(loc, eleTy, y0, y1); // y0 * y1 + auto y1Squared = rewriter.create(loc, eleTy, y1, y1); // y1^2 + + // compute denominator: x1^2 + y1^2 + auto denom = rewriter.create(loc, eleTy, x1Squared, y1Squared); + + // compute real numerator: x0*x1 + y0*y1 + auto realNumerator = rewriter.create(loc, eleTy, x0x1, y0y1); + // compute imag numerator: y0*x1 - x0*y1 + auto imagNumerator = rewriter.create(loc, eleTy, y0x1, x0y1); + + // compute final real and imaginary parts + auto realResult = rewriter.create(loc, eleTy, realNumerator, denom); + auto imagResult = rewriter.create(loc, eleTy, imagNumerator, denom); + + // construct the result complex number + auto undefComplex = rewriter.create(loc, resTy); + auto index0 = builder.getArrayAttr({builder.getI32IntegerAttr(0)}); // index for real part + auto index1 = builder.getArrayAttr({builder.getI32IntegerAttr(1)}); // index for imag part + auto complexWithReal = rewriter.create(loc, resTy, undefComplex, realResult, index0); // Insert real part + auto resComplex = rewriter.create(loc, resTy, complexWithReal, imagResult, index1); // Insert imaginary part + rewriter.replaceOp(callOp, resComplex.getResult()); return mlir::success(); } }; -class XSComplexConversion : public hlfir::impl::XSComplexConversionBase { +class XSComplexDivisionConversion : public hlfir::impl::XSComplexDivisionConversionBase { public: void runOnOperation() override { mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns.insert(context); + patterns.insert(context); mlir::GreedyRewriteConfig config; config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); if (mlir::failed(mlir::applyPatternsGreedily(module, std::move(patterns), config))) { - mlir::emitError(mlir::UnknownLoc::get(context), "failure in XS Complex HLFIR intrinsic lowering"); + mlir::emitError(mlir::UnknownLoc::get(context), "failure in XS Complex Division HLFIR intrinsic lowering"); signalPassFailure(); } } diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 7c48282bf0f67..6f728a018a4db 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -267,7 +267,7 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, } } pm.addPass(hlfir::createLowerHLFIROrderedAssignments()); - pm.addPass(hlfir::createXSComplexConversion()); + pm.addPass(hlfir::createXSComplexDivisionConversion()); pm.addPass(hlfir::createLowerHLFIRIntrinsics()); hlfir::BufferizeHLFIROptions bufferizeOptions; From a74a6b10564f6d584c08a89b80374617b3a69592 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Wed, 25 Jun 2025 10:39:45 +0800 Subject: [PATCH 03/13] revise code naming conventions --- flang/include/flang/Optimizer/HLFIR/Passes.td | 4 +-- ...thmeticBasedComplexDivisionConversion.cpp} | 33 +++++++++---------- .../Optimizer/HLFIR/Transforms/CMakeLists.txt | 2 +- flang/lib/Optimizer/Passes/Pipelines.cpp | 2 +- 4 files changed, 19 insertions(+), 22 deletions(-) rename flang/lib/Optimizer/HLFIR/Transforms/{XSComplexDivisionConversion.cpp => ArithmeticBasedComplexDivisionConversion.cpp} (79%) diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index 9e67743d2c410..9ac91121f434b 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -34,8 +34,8 @@ def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> { let summary = "Lower HLFIR transformational intrinsic operations"; } -def XSComplexDivisionConversion : Pass<"XS-complex", "::mlir::ModuleOp"> { - let summary = "XSComplexDivisionConversion transformational intrinsic operations"; +def ComplexDivisionConversion : Pass<"complex-division-conversion", "::mlir::ModuleOp"> { + let summary = "ComplexDivisionConversion transformational intrinsic operations"; } def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::mlir::ModuleOp"> { diff --git a/flang/lib/Optimizer/HLFIR/Transforms/XSComplexDivisionConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp similarity index 79% rename from flang/lib/Optimizer/HLFIR/Transforms/XSComplexDivisionConversion.cpp rename to flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp index 2a63d12d6bd3c..424c0dc1a3e50 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/XSComplexDivisionConversion.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp @@ -19,21 +19,21 @@ #define DEBUG_TYPE "xs-complex-conversion" namespace hlfir { -#define GEN_PASS_DEF_XSCOMPLEXDIVISIONCONVERSION +#define GEN_PASS_DEF_COMPLEXDIVISIONCONVERSION #include "flang/Optimizer/HLFIR/Passes.h.inc" } -static llvm::cl::opt EnableXSDivc( - "enable-XSDivc", llvm::cl::init(false), llvm::cl::Hidden, - llvm::cl::desc("Enable calling of XSDivc.")); +static llvm::cl::opt EnableArithmeticBasedComplexDiv( + "enable-arithmetic-based-complex-divsion", llvm::cl::init(false), llvm::cl::Hidden, + llvm::cl::desc("Enable calling of Arithmetic-based Complex Division.")); namespace { -class HlfirXSComplexDivisionConversion : public mlir::OpRewritePattern { +class HlfirArithmeticBasedComplexDivisionConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { - if (!EnableXSDivc) { - LLVM_DEBUG(llvm::dbgs() << "XS Complex Division support is currently disabled \n"); + if (!EnableArithmeticBasedComplexDiv) { + LLVM_DEBUG(llvm::dbgs() << "Arithmetic-based Complex Division support is currently disabled \n"); return mlir::failure(); } fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; @@ -56,6 +56,8 @@ class HlfirXSComplexDivisionConversion : public mlir::OpRewritePattern(loc, eleTy, x0, x1); // x0 * x1 auto x1Squared = rewriter.create(loc, eleTy, x1, x1); // x1^2 auto y0x1 = rewriter.create(loc, eleTy, y0, x1); // y0 * x1 @@ -63,13 +65,9 @@ class HlfirXSComplexDivisionConversion : public mlir::OpRewritePattern(loc, eleTy, y0, y1); // y0 * y1 auto y1Squared = rewriter.create(loc, eleTy, y1, y1); // y1^2 - // compute denominator: x1^2 + y1^2 - auto denom = rewriter.create(loc, eleTy, x1Squared, y1Squared); - - // compute real numerator: x0*x1 + y0*y1 - auto realNumerator = rewriter.create(loc, eleTy, x0x1, y0y1); - // compute imag numerator: y0*x1 - x0*y1 - auto imagNumerator = rewriter.create(loc, eleTy, y0x1, x0y1); + auto denom = rewriter.create(loc, eleTy, x1Squared, y1Squared); // x1^2 + y1^2 + auto realNumerator = rewriter.create(loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1 + auto imagNumerator = rewriter.create(loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1 // compute final real and imaginary parts auto realResult = rewriter.create(loc, eleTy, realNumerator, denom); @@ -85,21 +83,20 @@ class HlfirXSComplexDivisionConversion : public mlir::OpRewritePattern { +class ArithmeticBasedComplexDivisionConversion : public hlfir::impl::ComplexDivisionConversionBase { public: void runOnOperation() override { mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns.insert(context); + patterns.insert(context); mlir::GreedyRewriteConfig config; config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); if (mlir::failed(mlir::applyPatternsGreedily(module, std::move(patterns), config))) { - mlir::emitError(mlir::UnknownLoc::get(context), "failure in XS Complex Division HLFIR intrinsic lowering"); + mlir::emitError(mlir::UnknownLoc::get(context), "failure in Arithmetic-based Complex Division HLFIR intrinsic lowering"); signalPassFailure(); } } diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index 12ed3cefabfba..100d4ca263d38 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -12,7 +12,7 @@ add_flang_library(HLFIRTransforms SimplifyHLFIRIntrinsics.cpp OptimizedBufferization.cpp PropagateFortranVariableAttributes.cpp - XSComplexDivisionConversion.cpp + ArithmeticBasedComplexDivisionConversion.cpp DEPENDS CUFAttrsIncGen diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 6f728a018a4db..19fc86375248e 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -267,7 +267,7 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, } } pm.addPass(hlfir::createLowerHLFIROrderedAssignments()); - pm.addPass(hlfir::createXSComplexDivisionConversion()); + pm.addPass(hlfir::createComplexDivisionConversion()); pm.addPass(hlfir::createLowerHLFIRIntrinsics()); hlfir::BufferizeHLFIROptions bufferizeOptions; From 68f56007553dae10746f2c87dca8fd4356402730 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Wed, 25 Jun 2025 20:38:32 +0800 Subject: [PATCH 04/13] revise code naming conventions --- .../Transforms/ArithmeticBasedComplexDivisionConversion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp index 424c0dc1a3e50..37e80ae8c9668 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp @@ -24,7 +24,7 @@ namespace hlfir { } static llvm::cl::opt EnableArithmeticBasedComplexDiv( - "enable-arithmetic-based-complex-divsion", llvm::cl::init(false), llvm::cl::Hidden, + "enable-arithmetic-based-complex-div", llvm::cl::init(false), llvm::cl::Hidden, llvm::cl::desc("Enable calling of Arithmetic-based Complex Division.")); namespace { From 29faf80e811f81cdd966ff99ae5b0e1f256d1fb4 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Wed, 25 Jun 2025 20:41:07 +0800 Subject: [PATCH 05/13] add test case for the frontend pass of complex-division-conversion --- .../Fir/target-rewrite-complex-division.fir | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 flang/test/Fir/target-rewrite-complex-division.fir diff --git a/flang/test/Fir/target-rewrite-complex-division.fir b/flang/test/Fir/target-rewrite-complex-division.fir new file mode 100644 index 0000000000000..d9d149444d519 --- /dev/null +++ b/flang/test/Fir/target-rewrite-complex-division.fir @@ -0,0 +1,31 @@ +// RUN: fir-opt %s --complex-division-conversion --enable-arithmetic-based-complex-div | FileCheck %s + +func.func @test_double_complex_div(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref>) { + %0 = fir.load %arg1 : !fir.ref> + %1 = fir.load %arg2 : !fir.ref> + %2 = fir.extract_value %0, [0 : index] : (complex) -> f64 + %3 = fir.extract_value %0, [1 : index] : (complex) -> f64 + %4 = fir.extract_value %1, [0 : index] : (complex) -> f64 + %5 = fir.extract_value %1, [1 : index] : (complex) -> f64 + %6 = fir.call @__divdc3(%2, %3, %4, %5) fastmath : (f64, f64, f64, f64) -> complex + fir.store %6 to %arg0 : !fir.ref> + return +} + +// CHECK-LABEL: func.func @test_double_complex_div +// CHECK-NOT: fir.call @__divdc3 +// CHECK: %[[R1:.*]] = arith.mulf %2, %4 : f64 +// CHECK: %[[R2:.*]] = arith.mulf %4, %4 : f64 +// CHECK: %[[R3:.*]] = arith.mulf %3, %4 : f64 +// CHECK: %[[R4:.*]] = arith.mulf %2, %5 : f64 +// CHECK: %[[R5:.*]] = arith.mulf %3, %5 : f64 +// CHECK: %[[R6:.*]] = arith.mulf %5, %5 : f64 +// CHECK: %[[DENOM:.*]] = arith.addf %[[R2]], %[[R6]] : f64 +// CHECK: %[[NUM_RE:.*]] = arith.addf %[[R1]], %[[R5]] : f64 +// CHECK: %[[NUM_IM:.*]] = arith.subf %[[R3]], %[[R4]] : f64 +// CHECK: %[[RES_RE:.*]] = arith.divf %[[NUM_RE]], %[[DENOM]] : f64 +// CHECK: %[[RES_IM:.*]] = arith.divf %[[NUM_IM]], %[[DENOM]] : f64 +// CHECK: %[[U:.*]] = fir.undefined complex +// CHECK: %[[C0:.*]] = fir.insert_value %[[U]], %[[RES_RE]], [0 : i32] : (complex, f64) -> complex +// CHECK: %[[C1:.*]] = fir.insert_value %[[C0]], %[[RES_IM]], [1 : i32] : (complex, f64) -> complex +// CHECK: fir.store %[[C1]] to %arg0 : !fir.ref> \ No newline at end of file From 8ed6624cab28cfd490d982dc914e51b7eee83f82 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Wed, 25 Jun 2025 20:47:14 +0800 Subject: [PATCH 06/13] revise test case naming conventions --- .../Fir/target-rewrite-complex-division.fir | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/flang/test/Fir/target-rewrite-complex-division.fir b/flang/test/Fir/target-rewrite-complex-division.fir index d9d149444d519..6fae94ca3e225 100644 --- a/flang/test/Fir/target-rewrite-complex-division.fir +++ b/flang/test/Fir/target-rewrite-complex-division.fir @@ -20,12 +20,12 @@ func.func @test_double_complex_div(%arg0: !fir.ref>, %arg1: !fir.re // CHECK: %[[R4:.*]] = arith.mulf %2, %5 : f64 // CHECK: %[[R5:.*]] = arith.mulf %3, %5 : f64 // CHECK: %[[R6:.*]] = arith.mulf %5, %5 : f64 -// CHECK: %[[DENOM:.*]] = arith.addf %[[R2]], %[[R6]] : f64 -// CHECK: %[[NUM_RE:.*]] = arith.addf %[[R1]], %[[R5]] : f64 -// CHECK: %[[NUM_IM:.*]] = arith.subf %[[R3]], %[[R4]] : f64 -// CHECK: %[[RES_RE:.*]] = arith.divf %[[NUM_RE]], %[[DENOM]] : f64 -// CHECK: %[[RES_IM:.*]] = arith.divf %[[NUM_IM]], %[[DENOM]] : f64 -// CHECK: %[[U:.*]] = fir.undefined complex -// CHECK: %[[C0:.*]] = fir.insert_value %[[U]], %[[RES_RE]], [0 : i32] : (complex, f64) -> complex -// CHECK: %[[C1:.*]] = fir.insert_value %[[C0]], %[[RES_IM]], [1 : i32] : (complex, f64) -> complex -// CHECK: fir.store %[[C1]] to %arg0 : !fir.ref> \ No newline at end of file +// CHECK: %[[DENOM:.*]] = arith.addf %[[R2]], %[[R6]] : f64 +// CHECK: %[[NUM_RE:.*]] = arith.addf %[[R1]], %[[R5]] : f64 +// CHECK: %[[NUM_IM:.*]] = arith.subf %[[R3]], %[[R4]] : f64 +// CHECK: %[[RES_RE:.*]] = arith.divf %[[NUM_RE]], %[[DENOM]] : f64 +// CHECK: %[[RES_IM:.*]] = arith.divf %[[NUM_IM]], %[[DENOM]] : f64 +// CHECK: %[[U:.*]] = fir.undefined complex +// CHECK: %[[C0:.*]] = fir.insert_value %[[U]], %[[RES_RE]], [0 : i32] : (complex, f64) -> complex +// CHECK: %[[C1:.*]] = fir.insert_value %[[C0]], %[[RES_IM]], [1 : i32] : (complex, f64) -> complex +// CHECK: fir.store %[[C1]] to %arg0 : !fir.ref> \ No newline at end of file From 437f98977e6b7efbf5a2597b08154fb8d5467ad6 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 26 Jun 2025 11:29:49 +0800 Subject: [PATCH 07/13] revise code naming conventions --- flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt | 2 +- ...visionConversion.cpp => ComplexDivisionConversion.cpp} | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) rename flang/lib/Optimizer/HLFIR/Transforms/{ArithmeticBasedComplexDivisionConversion.cpp => ComplexDivisionConversion.cpp} (92%) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index 100d4ca263d38..81ac978e38723 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -12,7 +12,7 @@ add_flang_library(HLFIRTransforms SimplifyHLFIRIntrinsics.cpp OptimizedBufferization.cpp PropagateFortranVariableAttributes.cpp - ArithmeticBasedComplexDivisionConversion.cpp + ComplexDivisionConversion.cpp DEPENDS CUFAttrsIncGen diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp similarity index 92% rename from flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp rename to flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp index 37e80ae8c9668..d227959e94080 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ArithmeticBasedComplexDivisionConversion.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp @@ -16,7 +16,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include #include "llvm/Support/Debug.h" -#define DEBUG_TYPE "xs-complex-conversion" +#define DEBUG_TYPE "complex-conversion" namespace hlfir { #define GEN_PASS_DEF_COMPLEXDIVISIONCONVERSION @@ -28,7 +28,7 @@ static llvm::cl::opt EnableArithmeticBasedComplexDiv( llvm::cl::desc("Enable calling of Arithmetic-based Complex Division.")); namespace { -class HlfirArithmeticBasedComplexDivisionConversion : public mlir::OpRewritePattern { +class HlfirComplexDivisionConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { @@ -83,13 +83,13 @@ class HlfirArithmeticBasedComplexDivisionConversion : public mlir::OpRewritePatt return mlir::success(); } }; -class ArithmeticBasedComplexDivisionConversion : public hlfir::impl::ComplexDivisionConversionBase { +class ComplexDivisionConversion : public hlfir::impl::ComplexDivisionConversionBase { public: void runOnOperation() override { mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns.insert(context); + patterns.insert(context); mlir::GreedyRewriteConfig config; config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); From ba91f7875225370a86d8e9adba8690b07d93011f Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 26 Jun 2025 14:21:02 +0800 Subject: [PATCH 08/13] apply clang-format check --- .../Transforms/ComplexDivisionConversion.cpp | 109 ++++++++++++------ 1 file changed, 76 insertions(+), 33 deletions(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp index d227959e94080..63ede2cb5c185 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp @@ -14,38 +14,59 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include #include "llvm/Support/Debug.h" +#include #define DEBUG_TYPE "complex-conversion" namespace hlfir { #define GEN_PASS_DEF_COMPLEXDIVISIONCONVERSION #include "flang/Optimizer/HLFIR/Passes.h.inc" -} +} // namespace hlfir static llvm::cl::opt EnableArithmeticBasedComplexDiv( - "enable-arithmetic-based-complex-div", llvm::cl::init(false), llvm::cl::Hidden, + "enable-arithmetic-based-complex-div", llvm::cl::init(false), + llvm::cl::Hidden, llvm::cl::desc("Enable calling of Arithmetic-based Complex Division.")); namespace { -class HlfirComplexDivisionConversion : public mlir::OpRewritePattern { +/// This rewrite pattern class performs a custom transformation on FIR +/// 'fir.call' operations that invoke the '__divdc3' runtime function, which is +/// typically used to perform double-precision complex division. +/// +/// When the 'EnableArithmeticBasedComplexDiv' flag is enabled, this pattern +/// matches calls to '__divdc3', extracts the real and imaginary components of +/// the numerator and denominator, and replaces the function call with an +/// explicit computation using MLIR's arithmetic operations. +/// +/// Specifically, it replaces the call to '__divdc3(x0, y0, x1, y1)' —where +/// (x0 + y0i) / (x1 + y1i) is the intended operation—with the mathematically +/// equivalent expression: +/// real_part = (x0*x1 + y0*y1) / (x1^2 + y1^2) +/// imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2) +/// The result is then reassembled into a 'complex' value using FIR's +/// 'InsertValueOp' instructions. +class HlfirComplexDivisionConversion + : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; - llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, - mlir::PatternRewriter &rewriter) const override { + llvm::LogicalResult + matchAndRewrite(fir::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { if (!EnableArithmeticBasedComplexDiv) { - LLVM_DEBUG(llvm::dbgs() << "Arithmetic-based Complex Division support is currently disabled \n"); + LLVM_DEBUG(llvm::dbgs() << "Arithmetic-based Complex Division support is " + "currently disabled \n"); return mlir::failure(); } fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; const mlir::Location &loc = callOp.getLoc(); if (!callOp.getCallee()) { - LLVM_DEBUG(llvm::dbgs() << "No callee found for CallOp at " << loc << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "No callee found for CallOp at " << loc << "\n"); return mlir::failure(); - } + } const mlir::SymbolRefAttr &callee = *callOp.getCallee(); const auto &fctName = callee.getRootReference().getValue(); - if (fctName!= "__divdc3") + if (fctName != "__divdc3") return mlir::failure(); const mlir::Type &eleTy = callOp.getOperands()[0].getType(); @@ -57,48 +78,70 @@ class HlfirComplexDivisionConversion : public mlir::OpRewritePattern(loc, eleTy, x0, x1); // x0 * x1 - auto x1Squared = rewriter.create(loc, eleTy, x1, x1); // x1^2 - auto y0x1 = rewriter.create(loc, eleTy, y0, x1); // y0 * x1 - auto x0y1 = rewriter.create(loc, eleTy, x0, y1); // x0 * y1 - auto y0y1 = rewriter.create(loc, eleTy, y0, y1); // y0 * y1 - auto y1Squared = rewriter.create(loc, eleTy, y1, y1); // y1^2 + // (x0 + y0i)/(x1 + y1i) = ((x0*x1 + y0*y1)/(x1^2 + y1^2)) + ((y0*x1 - + // x0*y1)/(x1^2 + y1^2))i + auto x0x1 = + rewriter.create(loc, eleTy, x0, x1); // x0 * x1 + auto x1Squared = + rewriter.create(loc, eleTy, x1, x1); // x1^2 + auto y0x1 = + rewriter.create(loc, eleTy, y0, x1); // y0 * x1 + auto x0y1 = + rewriter.create(loc, eleTy, x0, y1); // x0 * y1 + auto y0y1 = + rewriter.create(loc, eleTy, y0, y1); // y0 * y1 + auto y1Squared = + rewriter.create(loc, eleTy, y1, y1); // y1^2 - auto denom = rewriter.create(loc, eleTy, x1Squared, y1Squared); // x1^2 + y1^2 - auto realNumerator = rewriter.create(loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1 - auto imagNumerator = rewriter.create(loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1 + auto denom = rewriter.create(loc, eleTy, x1Squared, + y1Squared); // x1^2 + y1^2 + auto realNumerator = rewriter.create( + loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1 + auto imagNumerator = rewriter.create( + loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1 // compute final real and imaginary parts - auto realResult = rewriter.create(loc, eleTy, realNumerator, denom); - auto imagResult = rewriter.create(loc, eleTy, imagNumerator, denom); + auto realResult = + rewriter.create(loc, eleTy, realNumerator, denom); + auto imagResult = + rewriter.create(loc, eleTy, imagNumerator, denom); // construct the result complex number auto undefComplex = rewriter.create(loc, resTy); - auto index0 = builder.getArrayAttr({builder.getI32IntegerAttr(0)}); // index for real part - auto index1 = builder.getArrayAttr({builder.getI32IntegerAttr(1)}); // index for imag part - auto complexWithReal = rewriter.create(loc, resTy, undefComplex, realResult, index0); // Insert real part - auto resComplex = rewriter.create(loc, resTy, complexWithReal, imagResult, index1); // Insert imaginary part + auto index0 = builder.getArrayAttr( + {builder.getI32IntegerAttr(0)}); // index for real part + auto index1 = builder.getArrayAttr( + {builder.getI32IntegerAttr(1)}); // index for imag part + auto complexWithReal = rewriter.create( + loc, resTy, undefComplex, realResult, index0); // Insert real part + auto resComplex = rewriter.create( + loc, resTy, complexWithReal, imagResult, + index1); // Insert imaginary part rewriter.replaceOp(callOp, resComplex.getResult()); return mlir::success(); } }; -class ComplexDivisionConversion : public hlfir::impl::ComplexDivisionConversionBase { +class ComplexDivisionConversion + : public hlfir::impl::ComplexDivisionConversionBase< + ComplexDivisionConversion> { public: void runOnOperation() override { mlir::ModuleOp module = this->getOperation(); mlir::MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.insert(context); - + mlir::GreedyRewriteConfig config; - config.setRegionSimplificationLevel(mlir::GreedySimplifyRegionLevel::Disabled); + config.setRegionSimplificationLevel( + mlir::GreedySimplifyRegionLevel::Disabled); - if (mlir::failed(mlir::applyPatternsGreedily(module, std::move(patterns), config))) - { - mlir::emitError(mlir::UnknownLoc::get(context), "failure in Arithmetic-based Complex Division HLFIR intrinsic lowering"); + if (mlir::failed( + mlir::applyPatternsGreedily(module, std::move(patterns), config))) { + mlir::emitError(mlir::UnknownLoc::get(context), + "failure in Arithmetic-based Complex Division HLFIR " + "intrinsic lowering"); signalPassFailure(); } } }; -} \ No newline at end of file +} // namespace \ No newline at end of file From 7b0b65506568bea40d66ce6cb2297d38b5f93873 Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 26 Jun 2025 16:08:02 +0800 Subject: [PATCH 09/13] revise complex div pass format and location --- flang/include/flang/Optimizer/HLFIR/Passes.td | 3 - .../Optimizer/HLFIR/Transforms/CMakeLists.txt | 1 - .../Transforms/ComplexDivisionConversion.cpp | 147 ------------------ .../Transforms/SimplifyHLFIRIntrinsics.cpp | 97 ++++++++++++ flang/lib/Optimizer/Passes/Pipelines.cpp | 1 - .../Fir/target-rewrite-complex-division.fir | 2 +- 6 files changed, 98 insertions(+), 153 deletions(-) delete mode 100644 flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index 9ac91121f434b..e59dc952537d2 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -34,9 +34,6 @@ def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> { let summary = "Lower HLFIR transformational intrinsic operations"; } -def ComplexDivisionConversion : Pass<"complex-division-conversion", "::mlir::ModuleOp"> { - let summary = "ComplexDivisionConversion transformational intrinsic operations"; -} def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::mlir::ModuleOp"> { let summary = "Lower HLFIR ordered assignments like forall and where operations"; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt index 81ac978e38723..cc74273d9c5d9 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/HLFIR/Transforms/CMakeLists.txt @@ -12,7 +12,6 @@ add_flang_library(HLFIRTransforms SimplifyHLFIRIntrinsics.cpp OptimizedBufferization.cpp PropagateFortranVariableAttributes.cpp - ComplexDivisionConversion.cpp DEPENDS CUFAttrsIncGen diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp deleted file mode 100644 index 63ede2cb5c185..0000000000000 --- a/flang/lib/Optimizer/HLFIR/Transforms/ComplexDivisionConversion.cpp +++ /dev/null @@ -1,147 +0,0 @@ -#include "flang/Optimizer/Builder/FIRBuilder.h" -#include "flang/Optimizer/Builder/HLFIRTools.h" -#include "flang/Optimizer/Builder/Todo.h" -#include "flang/Optimizer/Dialect/FIRDialect.h" -#include "flang/Optimizer/Dialect/FIROps.h" -#include "flang/Optimizer/Dialect/FIRType.h" -#include "flang/Optimizer/Dialect/Support/FIRContext.h" -#include "flang/Optimizer/HLFIR/HLFIRDialect.h" -#include "flang/Optimizer/HLFIR/HLFIROps.h" -#include "flang/Optimizer/HLFIR/Passes.h" -#include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" -#include -#define DEBUG_TYPE "complex-conversion" - -namespace hlfir { -#define GEN_PASS_DEF_COMPLEXDIVISIONCONVERSION -#include "flang/Optimizer/HLFIR/Passes.h.inc" -} // namespace hlfir - -static llvm::cl::opt EnableArithmeticBasedComplexDiv( - "enable-arithmetic-based-complex-div", llvm::cl::init(false), - llvm::cl::Hidden, - llvm::cl::desc("Enable calling of Arithmetic-based Complex Division.")); - -namespace { -/// This rewrite pattern class performs a custom transformation on FIR -/// 'fir.call' operations that invoke the '__divdc3' runtime function, which is -/// typically used to perform double-precision complex division. -/// -/// When the 'EnableArithmeticBasedComplexDiv' flag is enabled, this pattern -/// matches calls to '__divdc3', extracts the real and imaginary components of -/// the numerator and denominator, and replaces the function call with an -/// explicit computation using MLIR's arithmetic operations. -/// -/// Specifically, it replaces the call to '__divdc3(x0, y0, x1, y1)' —where -/// (x0 + y0i) / (x1 + y1i) is the intended operation—with the mathematically -/// equivalent expression: -/// real_part = (x0*x1 + y0*y1) / (x1^2 + y1^2) -/// imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2) -/// The result is then reassembled into a 'complex' value using FIR's -/// 'InsertValueOp' instructions. -class HlfirComplexDivisionConversion - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - llvm::LogicalResult - matchAndRewrite(fir::CallOp callOp, - mlir::PatternRewriter &rewriter) const override { - if (!EnableArithmeticBasedComplexDiv) { - LLVM_DEBUG(llvm::dbgs() << "Arithmetic-based Complex Division support is " - "currently disabled \n"); - return mlir::failure(); - } - fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; - const mlir::Location &loc = callOp.getLoc(); - if (!callOp.getCallee()) { - LLVM_DEBUG(llvm::dbgs() - << "No callee found for CallOp at " << loc << "\n"); - return mlir::failure(); - } - - const mlir::SymbolRefAttr &callee = *callOp.getCallee(); - const auto &fctName = callee.getRootReference().getValue(); - if (fctName != "__divdc3") - return mlir::failure(); - - const mlir::Type &eleTy = callOp.getOperands()[0].getType(); - const mlir::Type &resTy = callOp.getResult(0).getType(); - - auto x0 = callOp.getOperands()[0]; // real part of numerator : x0 - auto y0 = callOp.getOperands()[1]; // imaginary part of numerator : y0 - auto x1 = callOp.getOperands()[2]; // real part of denominator : x1 - auto y1 = callOp.getOperands()[3]; // imaginary part of denominator : y1 - - // standard complex division formula: - // (x0 + y0i)/(x1 + y1i) = ((x0*x1 + y0*y1)/(x1^2 + y1^2)) + ((y0*x1 - - // x0*y1)/(x1^2 + y1^2))i - auto x0x1 = - rewriter.create(loc, eleTy, x0, x1); // x0 * x1 - auto x1Squared = - rewriter.create(loc, eleTy, x1, x1); // x1^2 - auto y0x1 = - rewriter.create(loc, eleTy, y0, x1); // y0 * x1 - auto x0y1 = - rewriter.create(loc, eleTy, x0, y1); // x0 * y1 - auto y0y1 = - rewriter.create(loc, eleTy, y0, y1); // y0 * y1 - auto y1Squared = - rewriter.create(loc, eleTy, y1, y1); // y1^2 - - auto denom = rewriter.create(loc, eleTy, x1Squared, - y1Squared); // x1^2 + y1^2 - auto realNumerator = rewriter.create( - loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1 - auto imagNumerator = rewriter.create( - loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1 - - // compute final real and imaginary parts - auto realResult = - rewriter.create(loc, eleTy, realNumerator, denom); - auto imagResult = - rewriter.create(loc, eleTy, imagNumerator, denom); - - // construct the result complex number - auto undefComplex = rewriter.create(loc, resTy); - auto index0 = builder.getArrayAttr( - {builder.getI32IntegerAttr(0)}); // index for real part - auto index1 = builder.getArrayAttr( - {builder.getI32IntegerAttr(1)}); // index for imag part - auto complexWithReal = rewriter.create( - loc, resTy, undefComplex, realResult, index0); // Insert real part - auto resComplex = rewriter.create( - loc, resTy, complexWithReal, imagResult, - index1); // Insert imaginary part - rewriter.replaceOp(callOp, resComplex.getResult()); - return mlir::success(); - } -}; -class ComplexDivisionConversion - : public hlfir::impl::ComplexDivisionConversionBase< - ComplexDivisionConversion> { -public: - void runOnOperation() override { - mlir::ModuleOp module = this->getOperation(); - mlir::MLIRContext *context = &getContext(); - mlir::RewritePatternSet patterns(context); - patterns.insert(context); - - mlir::GreedyRewriteConfig config; - config.setRegionSimplificationLevel( - mlir::GreedySimplifyRegionLevel::Disabled); - - if (mlir::failed( - mlir::applyPatternsGreedily(module, std::move(patterns), config))) { - mlir::emitError(mlir::UnknownLoc::get(context), - "failure in Arithmetic-based Complex Division HLFIR " - "intrinsic lowering"); - signalPassFailure(); - } - } -}; -} // namespace \ No newline at end of file diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 79582390d1294..bdf465ee605e5 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -35,6 +35,9 @@ static llvm::cl::opt forceMatmulAsElemental( llvm::cl::desc("Expand hlfir.matmul as elemental operation"), llvm::cl::init(false)); +static llvm::cl::opt EnableComplexDivConverter( + "enable-complex-div-converter", llvm::cl::init(false), llvm::cl::Hidden, + llvm::cl::desc("Enable calling of Complex Divi Converter.")); namespace { // Helper class to generate operations related to computing @@ -2320,6 +2323,99 @@ class ReshapeAsElementalConversion } }; +/// This rewrite pattern class performs a custom transformation on FIR +/// 'fir.call' operations that invoke the '__divdc3' runtime function, which is +/// typically used to perform double-precision complex division. +/// +/// When the 'EnableArithmeticBasedComplexDiv' flag is enabled, this pattern +/// matches calls to '__divdc3', extracts the real and imaginary components of +/// the numerator and denominator, and replaces the function call with an +/// explicit computation using MLIR's arithmetic operations. +/// +/// Specifically, it replaces the call to '__divdc3(x0, y0, x1, y1)' —where +/// (x0 + y0i) / (x1 + y1i) is the intended operation—with the mathematically +/// equivalent expression: +/// real_part = (x0*x1 + y0*y1) / (x1^2 + y1^2) +/// imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2) +/// The result is then reassembled into a 'complex' value using FIR's +/// 'InsertValueOp' instructions. +class HlfirComplexDivisionConversion + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + llvm::LogicalResult + matchAndRewrite(fir::CallOp callOp, + mlir::PatternRewriter &rewriter) const override { + if (!EnableComplexDivConverter) { + LLVM_DEBUG(llvm::dbgs() << "Arithmetic-based Complex Division support is " + "currently disabled \n"); + return mlir::failure(); + } + fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; + const mlir::Location &loc = callOp.getLoc(); + if (!callOp.getCallee()) { + LLVM_DEBUG(llvm::dbgs() + << "No callee found for CallOp at " << loc << "\n"); + return mlir::failure(); + } + + const mlir::SymbolRefAttr &callee = *callOp.getCallee(); + const auto &fctName = callee.getRootReference().getValue(); + if (fctName != "__divdc3") + return mlir::failure(); + + const mlir::Type &eleTy = callOp.getOperands()[0].getType(); + const mlir::Type &resTy = callOp.getResult(0).getType(); + + auto x0 = callOp.getOperands()[0]; // real part of numerator + auto y0 = callOp.getOperands()[1]; // imaginary part of numerator + auto x1 = callOp.getOperands()[2]; // real part of denominator + auto y1 = callOp.getOperands()[3]; // imaginary part of denominator + + // standard complex division formula: + // (x0 + y0i)/(x1 + y1i) = ((x0*x1 + y0*y1)/(x1^2 + y1^2)) + ((y0*x1 - + // x0*y1)/(x1^2 + y1^2))i + auto x0x1 = + rewriter.create(loc, eleTy, x0, x1); // x0 * x1 + auto x1Squared = + rewriter.create(loc, eleTy, x1, x1); // x1^2 + auto y0x1 = + rewriter.create(loc, eleTy, y0, x1); // y0 * x1 + auto x0y1 = + rewriter.create(loc, eleTy, x0, y1); // x0 * y1 + auto y0y1 = + rewriter.create(loc, eleTy, y0, y1); // y0 * y1 + auto y1Squared = + rewriter.create(loc, eleTy, y1, y1); // y1^2 + + auto denom = rewriter.create(loc, eleTy, x1Squared, + y1Squared); // x1^2 + y1^2 + auto realNumerator = rewriter.create( + loc, eleTy, x0x1, y0y1); // x0*x1 + y0*y1 + auto imagNumerator = rewriter.create( + loc, eleTy, y0x1, x0y1); // y0*x1 - x0*y1 + + // compute final real and imaginary parts + auto realResult = + rewriter.create(loc, eleTy, realNumerator, denom); + auto imagResult = + rewriter.create(loc, eleTy, imagNumerator, denom); + + // construct the result complex number + auto undefComplex = rewriter.create(loc, resTy); + auto index0 = builder.getArrayAttr( + {builder.getI32IntegerAttr(0)}); // index for real part + auto index1 = builder.getArrayAttr( + {builder.getI32IntegerAttr(1)}); // index for imag part + auto complexWithReal = rewriter.create( + loc, resTy, undefComplex, realResult, index0); // Insert real part + auto resComplex = rewriter.create( + loc, resTy, complexWithReal, imagResult, + index1); // Insert imaginary part + rewriter.replaceOp(callOp, resComplex.getResult()); + return mlir::success(); + } +}; + class SimplifyHLFIRIntrinsics : public hlfir::impl::SimplifyHLFIRIntrinsicsBase { public: @@ -2365,6 +2461,7 @@ class SimplifyHLFIRIntrinsics patterns.insert(context); patterns.insert(context); + patterns.insert(context); if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 19fc86375248e..70f57bdeddd3f 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -267,7 +267,6 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, bool enableOpenMP, } } pm.addPass(hlfir::createLowerHLFIROrderedAssignments()); - pm.addPass(hlfir::createComplexDivisionConversion()); pm.addPass(hlfir::createLowerHLFIRIntrinsics()); hlfir::BufferizeHLFIROptions bufferizeOptions; diff --git a/flang/test/Fir/target-rewrite-complex-division.fir b/flang/test/Fir/target-rewrite-complex-division.fir index 6fae94ca3e225..f38975a1de1a3 100644 --- a/flang/test/Fir/target-rewrite-complex-division.fir +++ b/flang/test/Fir/target-rewrite-complex-division.fir @@ -1,4 +1,4 @@ -// RUN: fir-opt %s --complex-division-conversion --enable-arithmetic-based-complex-div | FileCheck %s +// RUN: fir-opt %s --simplify-hlfir-intrinsics --enable-complex-div-converter | FileCheck %s func.func @test_double_complex_div(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref>) { %0 = fir.load %arg1 : !fir.ref> From 9da44582ab844cf29e97fce597d6d759b8035fde Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 26 Jun 2025 16:10:28 +0800 Subject: [PATCH 10/13] remove redundant lines --- flang/include/flang/Optimizer/HLFIR/Passes.td | 1 - 1 file changed, 1 deletion(-) diff --git a/flang/include/flang/Optimizer/HLFIR/Passes.td b/flang/include/flang/Optimizer/HLFIR/Passes.td index e59dc952537d2..04d7aec5fe489 100644 --- a/flang/include/flang/Optimizer/HLFIR/Passes.td +++ b/flang/include/flang/Optimizer/HLFIR/Passes.td @@ -34,7 +34,6 @@ def LowerHLFIRIntrinsics : Pass<"lower-hlfir-intrinsics", "::mlir::ModuleOp"> { let summary = "Lower HLFIR transformational intrinsic operations"; } - def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::mlir::ModuleOp"> { let summary = "Lower HLFIR ordered assignments like forall and where operations"; let options = [ From 6744fc7364c25a5304ea05c2e8ef3e9132d27f0c Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 26 Jun 2025 16:19:48 +0800 Subject: [PATCH 11/13] revise naming conventions --- .../Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index bdf465ee605e5..cd8392070884a 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -2327,7 +2327,7 @@ class ReshapeAsElementalConversion /// 'fir.call' operations that invoke the '__divdc3' runtime function, which is /// typically used to perform double-precision complex division. /// -/// When the 'EnableArithmeticBasedComplexDiv' flag is enabled, this pattern +/// When the 'EnableComplexDivConverter' flag is enabled, this pattern /// matches calls to '__divdc3', extracts the real and imaginary components of /// the numerator and denominator, and replaces the function call with an /// explicit computation using MLIR's arithmetic operations. @@ -2339,7 +2339,7 @@ class ReshapeAsElementalConversion /// imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2) /// The result is then reassembled into a 'complex' value using FIR's /// 'InsertValueOp' instructions. -class HlfirComplexDivisionConversion +class ComplexDivisionConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; llvm::LogicalResult @@ -2461,7 +2461,7 @@ class SimplifyHLFIRIntrinsics patterns.insert(context); patterns.insert(context); - patterns.insert(context); + patterns.insert(context); if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { From 316092edcc1603141fac306da31987cf2d43278f Mon Sep 17 00:00:00 2001 From: Hanwen Date: Thu, 26 Jun 2025 17:40:14 +0800 Subject: [PATCH 12/13] revise naming conventions --- .../Transforms/SimplifyHLFIRIntrinsics.cpp | 31 ++++++++++++------- .../Fir/target-rewrite-complex-division.fir | 2 +- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index cd8392070884a..7dfc1613b4ac6 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -35,9 +35,10 @@ static llvm::cl::opt forceMatmulAsElemental( llvm::cl::desc("Expand hlfir.matmul as elemental operation"), llvm::cl::init(false)); -static llvm::cl::opt EnableComplexDivConverter( - "enable-complex-div-converter", llvm::cl::init(false), llvm::cl::Hidden, - llvm::cl::desc("Enable calling of Complex Divi Converter.")); +static llvm::cl::opt forceComplexDivAsArithmetic( + "flang-complex-div-converter", llvm::cl::init(false), llvm::cl::Hidden, + llvm::cl::desc("Force complex div as arithmetic calculation.")); + namespace { // Helper class to generate operations related to computing @@ -2324,11 +2325,11 @@ class ReshapeAsElementalConversion }; /// This rewrite pattern class performs a custom transformation on FIR -/// 'fir.call' operations that invoke the '__divdc3' runtime function, which is +/// 'fir.call' operation that invoke the '__divdc3' runtime function, which is /// typically used to perform double-precision complex division. /// -/// When the 'EnableComplexDivConverter' flag is enabled, this pattern -/// matches calls to '__divdc3', extracts the real and imaginary components of +/// If the 'forceComplexDivAsArithmetic' flag option is true, this pattern +/// matches call to '__divdc3', extracts the real and imaginary components of /// the numerator and denominator, and replaces the function call with an /// explicit computation using MLIR's arithmetic operations. /// @@ -2339,15 +2340,15 @@ class ReshapeAsElementalConversion /// imag_part = (y0*x1 - x0*y1) / (x1^2 + y1^2) /// The result is then reassembled into a 'complex' value using FIR's /// 'InsertValueOp' instructions. -class ComplexDivisionConversion - : public mlir::OpRewritePattern { +class ComplexDivisionConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; llvm::LogicalResult matchAndRewrite(fir::CallOp callOp, mlir::PatternRewriter &rewriter) const override { - if (!EnableComplexDivConverter) { - LLVM_DEBUG(llvm::dbgs() << "Arithmetic-based Complex Division support is " - "currently disabled \n"); + if (!forceComplexDivAsArithmetic) { + LLVM_DEBUG(llvm::dbgs() + << "Complex division with arithmetic calculation support is " + "currently disabled \n"); return mlir::failure(); } fir::FirOpBuilder builder{rewriter, callOp.getOperation()}; @@ -2461,7 +2462,13 @@ class SimplifyHLFIRIntrinsics patterns.insert(context); patterns.insert(context); - patterns.insert(context); + + /// If the 'forceComplexDivAsArithmetic' flag option is true, this pattern + /// matches call to '__divdc3', extracts the real and imaginary components + /// of the numerator and denominator, and replaces the function call with an + /// explicit computation using MLIR's arithmetic operations. + if (forceComplexDivAsArithmetic) + patterns.insert(context); if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { diff --git a/flang/test/Fir/target-rewrite-complex-division.fir b/flang/test/Fir/target-rewrite-complex-division.fir index f38975a1de1a3..24744a37acb3d 100644 --- a/flang/test/Fir/target-rewrite-complex-division.fir +++ b/flang/test/Fir/target-rewrite-complex-division.fir @@ -1,4 +1,4 @@ -// RUN: fir-opt %s --simplify-hlfir-intrinsics --enable-complex-div-converter | FileCheck %s +// RUN: fir-opt %s --simplify-hlfir-intrinsics --flang-complex-div-converter | FileCheck %s func.func @test_double_complex_div(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref>) { %0 = fir.load %arg1 : !fir.ref> From ef7b57275d10d36f276b8dad58c2ade1a02bcd20 Mon Sep 17 00:00:00 2001 From: fanyikang Date: Fri, 27 Jun 2025 10:56:32 +0800 Subject: [PATCH 13/13] revise comment naming convention Co-Authored-By: ict-ql <168183727+ict-ql@users.noreply.github.com> Co-Authored-By: Chyaka <52224511+liliumshade@users.noreply.github.com> --- flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 7dfc1613b4ac6..2736da374a687 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -2332,7 +2332,6 @@ class ReshapeAsElementalConversion /// matches call to '__divdc3', extracts the real and imaginary components of /// the numerator and denominator, and replaces the function call with an /// explicit computation using MLIR's arithmetic operations. -/// /// Specifically, it replaces the call to '__divdc3(x0, y0, x1, y1)' —where /// (x0 + y0i) / (x1 + y1i) is the intended operation—with the mathematically /// equivalent expression: