diff --git a/mlir/include/Driver/CompilerDriver.h b/mlir/include/Driver/CompilerDriver.h index 7c0ba8b9eb..9e03baf74c 100644 --- a/mlir/include/Driver/CompilerDriver.h +++ b/mlir/include/Driver/CompilerDriver.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -22,7 +23,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" #include "Driver/Pipelines.h" @@ -35,11 +35,11 @@ namespace driver { // low-level messages, we might want to hide these. enum class Verbosity { Silent = 0, Urgent = 1, Debug = 2, All = 3 }; -enum SaveTemps { None, AfterPipeline, AfterPass }; +enum class SaveTemps { None, AfterPipeline, AfterPass }; -enum Action { OPT, Translate, LLC, All }; +enum class Action { OPT, Translate, LLC, All }; -enum InputType { MLIR, LLVMIR, OTHER }; +enum class InputType { MLIR, LLVMIR, OTHER }; /// Helper verbose reporting macro. #define CO_MSG(opt, level, op) \ @@ -84,7 +84,7 @@ struct CompilerOptions { }; struct CompilerOutput { - typedef std::unordered_map PipelineOutputs; + using PipelineOutputs = std::unordered_map; std::string outputFilename; std::string outIR; std::string diagnosticMessages; @@ -94,9 +94,10 @@ struct CompilerOutput { bool isCheckpointFound; // Gets the next pipeline dump file name, prefixed with number. - std::string nextPipelineDumpFilename(std::string pipelineName, std::string ext = ".mlir") + std::string nextPipelineDumpFilename(const std::string &pipelineName, + const std::string &ext = ".mlir") { - return std::filesystem::path(std::to_string(this->pipelineCounter++) + "_" + pipelineName) + return std::filesystem::path(std::format("{}_{}", this->pipelineCounter++, pipelineName)) .replace_extension(ext); }; }; diff --git a/mlir/lib/Driver/CompilerDriver.cpp b/mlir/lib/Driver/CompilerDriver.cpp index d72ef39ebb..1605d22a7e 100644 --- a/mlir/lib/Driver/CompilerDriver.cpp +++ b/mlir/lib/Driver/CompilerDriver.cpp @@ -17,12 +17,10 @@ #include #include #include -#include #include #include #include #include -#include #include "mhlo/IR/register.h" #include "mhlo/transforms/passes.h" @@ -41,7 +39,6 @@ #include "llvm/IRReader/IRReader.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Passes/PassBuilder.h" -#include "llvm/Support/FileSystem.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" @@ -65,15 +62,12 @@ #include "Gradient/IR/GradientDialect.h" #include "Gradient/IR/GradientInterfaces.h" #include "Gradient/Transforms/BufferizableOpInterfaceImpl.h" -#include "Gradient/Transforms/Passes.h" #include "Ion/IR/IonDialect.h" #include "MBQC/IR/MBQCDialect.h" #include "Mitigation/IR/MitigationDialect.h" -#include "Mitigation/Transforms/Passes.h" #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" -#include "Quantum/Transforms/Passes.h" #include "Enzyme.h" #include "Timer.hpp" @@ -98,19 +92,19 @@ namespace catalyst::utils { */ class LinesCount { private: - inline static void print(const std::string &opStrBuf, const std::string &name) + inline static void print(std::string_view opStrBuf, std::string_view name) { - const auto num_lines = std::count(opStrBuf.cbegin(), opStrBuf.cend(), '\n'); + const auto num_lines = std::ranges::count(opStrBuf, '\n'); if (!name.empty()) { std::cerr << "[DIAGNOSTICS] After " << std::setw(25) << std::left << name; } std::cerr << "\t" << std::fixed << "programsize: " << num_lines << std::fixed << " lines\n"; } - inline static void store(const std::string &opStrBuf, const std::string &name, + inline static void store(std::string_view opStrBuf, std::string_view name, const std::filesystem::path &file_path) { - const auto num_lines = std::count(opStrBuf.cbegin(), opStrBuf.cend(), '\n'); + const auto num_lines = std::ranges::count(opStrBuf, '\n'); const std::string_view key_padding = " "; const std::string_view val_padding = " "; @@ -137,7 +131,7 @@ class LinesCount { ofile.close(); } - inline static void dump(const std::string &opStrBuf, const std::string &name = {}) + inline static void dump(std::string_view opStrBuf, std::string_view name = {}) { char *file = getenv("DIAGNOSTICS_RESULTS_PATH"); if (!file) { @@ -151,11 +145,11 @@ class LinesCount { public: [[nodiscard]] inline static bool is_diagnostics_enabled() { - char *value = getenv("ENABLE_DIAGNOSTICS"); + const char *const value = getenv("ENABLE_DIAGNOSTICS"); return value && std::string(value) == "ON"; } - static void Operation(Operation *op, const std::string &name = {}) + static void Operation(const Operation *op, const std::string &name = {}) { if (!is_diagnostics_enabled()) { return; @@ -208,11 +202,11 @@ std::string joinPasses(const llvm::SmallVector &passes) } struct CatalystIRPrinterConfig : public PassManager::IRPrinterConfig { - typedef std::function PrintHandler; + using PrintHandler = std::function; PrintHandler printHandler; - CatalystIRPrinterConfig(PrintHandler printHandler) - : IRPrinterConfig(/*printModuleScope=*/true), printHandler(printHandler) + explicit CatalystIRPrinterConfig(PrintHandler printHandler) + : IRPrinterConfig(/*printModuleScope=*/true), printHandler(std::move(printHandler)) { } @@ -225,15 +219,16 @@ struct CatalystIRPrinterConfig : public PassManager::IRPrinterConfig { }; struct CatalystPassInstrumentation : public PassInstrumentation { - typedef std::function PassCallback; + using PassCallback = std::function; PassCallback beforePassCallback; PassCallback afterPassCallback; PassCallback afterPassFailedCallback; CatalystPassInstrumentation(PassCallback beforePassCallback, PassCallback afterPassCallback, PassCallback afterPassFailedCallback) - : beforePassCallback(beforePassCallback), afterPassCallback(afterPassCallback), - afterPassFailedCallback(afterPassFailedCallback) + : beforePassCallback(std::move(beforePassCallback)), + afterPassCallback(std::move(afterPassCallback)), + afterPassFailedCallback(std::move(afterPassFailedCallback)) { } @@ -270,7 +265,7 @@ OwningOpRef parseMLIRSource(MLIRContext *ctx, const llvm::SourceMgr &s bool containsGradients(mlir::ModuleOp moduleOp) { bool contain = false; - moduleOp.walk([&](catalyst::gradient::GradientOpInterface op) { + moduleOp.walk([&]([[maybe_unused]] catalyst::gradient::GradientOpInterface op) { contain = true; return WalkResult::interrupt(); }); @@ -310,7 +305,7 @@ void registerAllCatalystDialects(DialectRegistry ®istry) // Determines if the compilation stage should be executed if a checkpointStage is given bool shouldRunStage(const CompilerOptions &options, CompilerOutput &output, - const std::string &stageName) + std::string_view stageName) { if (options.checkpointStage.empty()) { return true; @@ -354,7 +349,7 @@ LogicalResult runCoroLLVMPasses(const CompilerOptions &options, // Optimize the IR! CoroPM.run(*llvmModule.get(), MAM); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { std::string tmp; llvm::raw_string_ostream rawStringOstream{tmp}; llvmModule->print(rawStringOstream, nullptr); @@ -406,7 +401,7 @@ LogicalResult runO2LLVMPasses(const CompilerOptions &options, // Optimize the IR! MPM.run(*llvmModule.get(), MAM); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { std::string tmp; llvm::raw_string_ostream rawStringOstream{tmp}; llvmModule->print(rawStringOstream, nullptr); @@ -454,7 +449,7 @@ LogicalResult runEnzymePasses(const CompilerOptions &options, // Optimize the IR! MPM.run(*llvmModule.get(), MAM); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { std::string tmp; llvm::raw_string_ostream rawStringOstream{tmp}; llvmModule->print(rawStringOstream, nullptr); @@ -469,7 +464,8 @@ std::string readInputFile(const std::string &filename) { if (filename == "-") { std::stringstream buffer; - std::istreambuf_iterator begin(std::cin), end; + std::istreambuf_iterator begin(std::cin); + std::istreambuf_iterator end; buffer << std::string(begin, end); return buffer.str(); } @@ -486,7 +482,8 @@ LogicalResult preparePassManager(PassManager &pm, const CompilerOptions &options CompilerOutput &output, catalyst::utils::Timer &timer, TimingScope &timing) { - auto beforePassCallback = [&](Pass *pass, Operation *op) { + auto beforePassCallback = [&]([[maybe_unused]] const Pass *const pass, + [[maybe_unused]] const Operation *const op) { if (options.verbosity >= Verbosity::Debug && !timer.is_active()) { timer.start(); } @@ -494,7 +491,7 @@ LogicalResult preparePassManager(PassManager &pm, const CompilerOptions &options // For each pipeline-terminating pass, print the IR into the corresponding dump file and // into a diagnostic output buffer. Note that one pass can terminate multiple pipelines. - auto afterPassCallback = [&](Pass *pass, Operation *op) { + auto afterPassCallback = [&](const Pass *const pass, const Operation *const op) { auto pipelineName = pass->getName(); if (options.verbosity >= Verbosity::Debug) { timer.dump(pipelineName.str(), /*add_endl */ false); @@ -514,12 +511,12 @@ LogicalResult preparePassManager(PassManager &pm, const CompilerOptions &options }; // For each failed pass, print the owner pipeline name into a diagnostic stream. - auto afterPassFailedCallback = [&](Pass *pass, Operation *op) { + auto afterPassFailedCallback = [&](const Pass *const pass, const Operation *const op) { options.diagnosticStream << "While processing '" << pass->getName().str() << "' pass "; std::string tmp; llvm::raw_string_ostream s{tmp}; s << *op; - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { dumpToFile(options, output.nextPipelineDumpFilename(pass->getName().str() + "_FAILED"), tmp); } @@ -532,8 +529,8 @@ LogicalResult preparePassManager(PassManager &pm, const CompilerOptions &options if (failed(config.setupPassPipeline(pm))) return failure(); pm.enableTiming(timing); - pm.addInstrumentation(std::unique_ptr(new CatalystPassInstrumentation( - beforePassCallback, afterPassCallback, afterPassFailedCallback))); + pm.addInstrumentation(std::make_unique( + beforePassCallback, afterPassCallback, afterPassFailedCallback)); return success(); } @@ -559,7 +556,7 @@ LogicalResult configurePipeline(PassManager &pm, const CompilerOptions &options, LogicalResult runPipeline(PassManager &pm, const CompilerOptions &options, CompilerOutput &output, Pipeline &pipeline, bool clHasManualPipeline, ModuleOp moduleOp) { - if (!shouldRunStage(options, output, pipeline.getName()) || pipeline.getPasses().size() == 0) { + if (!shouldRunStage(options, output, pipeline.getName()) || pipeline.getPasses().empty()) { return success(); } if (failed(configurePipeline(pm, options, pipeline, clHasManualPipeline))) { @@ -570,7 +567,8 @@ LogicalResult runPipeline(PassManager &pm, const CompilerOptions &options, Compi llvm::errs() << "Failed to run pipeline: " << pipeline.getName() << "\n"; return failure(); } - if (options.keepIntermediate && (options.checkpointStage.empty() || output.isCheckpointFound)) { + if ((options.keepIntermediate != SaveTemps::None) && + (options.checkpointStage.empty() || output.isCheckpointFound)) { std::string tmp; llvm::raw_string_ostream s{tmp}; s << moduleOp; @@ -583,7 +581,8 @@ LogicalResult runLowering(const CompilerOptions &options, MLIRContext *ctx, Modu CompilerOutput &output, TimingScope &timing) { - if (options.keepIntermediate && (options.checkpointStage.empty() || output.isCheckpointFound)) { + if ((options.keepIntermediate != SaveTemps::None) && + (options.checkpointStage.empty() || output.isCheckpointFound)) { std::string tmp; llvm::raw_string_ostream s{tmp}; s << moduleOp; @@ -599,7 +598,7 @@ LogicalResult runLowering(const CompilerOptions &options, MLIRContext *ctx, Modu return failure(); } - bool clHasIndividualPass = pm.size() > 0; + bool clHasIndividualPass = !pm.empty(); bool clHasManualPipeline = !options.pipelinesCfg.empty(); if (clHasIndividualPass && clHasManualPipeline) { llvm::errs() << "--catalyst-pipeline option can't be used with individual pass options " @@ -648,7 +647,7 @@ LogicalResult verifyInputType(const CompilerOptions &options, InputType inType) } LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput &output, - DialectRegistry ®istry) + const DialectRegistry ®istry) { using timer = catalyst::utils::Timer; @@ -663,7 +662,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & ctx.loadAllAvailableDialects(); ScopedDiagnosticHandler scopedHandler( - &ctx, [&](Diagnostic &diag) { diag.print(options.diagnosticStream); }); + &ctx, [&](const Diagnostic &diag) { diag.print(options.diagnosticStream); }); llvm::LLVMContext llvmContext; std::shared_ptr llvmModule; @@ -728,7 +727,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & return failure(); } output.outIR.clear(); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { outIRStream << *mlirModule; } optTiming.stop(); @@ -748,7 +747,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & inType = InputType::LLVMIR; catalyst::utils::LinesCount::Module(*llvmModule); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { std::string tmp; llvm::raw_string_ostream rawStringOstream{tmp}; llvmModule->print(rawStringOstream, nullptr); @@ -756,7 +755,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & dumpToFile(options, outFile, tmp); } output.outIR.clear(); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { outIRStream << *llvmModule; } translateTiming.stop(); @@ -824,7 +823,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & TimingScope outputTiming = llcTiming.nest("compileObject"); output.outIR.clear(); - if (options.keepIntermediate) { + if (options.keepIntermediate != SaveTemps::None) { outIRStream << *llvmModule; } @@ -850,7 +849,7 @@ LogicalResult QuantumDriverMain(const CompilerOptions &options, CompilerOutput & outfile->keep(); } - if (options.keepIntermediate and output.outputFilename != "-") { + if ((options.keepIntermediate != SaveTemps::None) && output.outputFilename != "-") { outfile->os() << output.outIR; outfile->keep(); } @@ -900,7 +899,7 @@ std::vector parsePipelines(const cl::list &catalystPipeli passesStr.split(passList, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); llvm::SmallVector passes; - for (auto &pass : passList) { + for (const auto &pass : passList) { passes.push_back(pass.trim().str()); } @@ -972,7 +971,8 @@ int QuantumDriverMainFromCL(int argc, char **argv) catalyst::quantum::registerBufferizableOpInterfaceExternalModels(registry); // Register and parse command line options. - std::string inputFilename, outputFilename; + std::string inputFilename; + std::string outputFilename; std::string helpStr = "Catalyst Command Line Interface options. \n" "Below, there is a complete list of options for the Catalyst CLI tool" "In the first section, you can find the options that are used to" @@ -990,7 +990,7 @@ int QuantumDriverMainFromCL(int argc, char **argv) return 1; } - std::unique_ptr output(new CompilerOutput()); + auto output = std::make_unique(); assert(output); output->outputFilename = outputFilename; llvm::raw_string_ostream errStream{output->diagnosticMessages};