From 1522c61697d307ec0271a3def5f012174f02292f Mon Sep 17 00:00:00 2001 From: Kiran Chandramohan Date: Tue, 12 Dec 2023 15:47:22 +0000 Subject: [PATCH 1/7] [Flang] WIP: Allow compiler directives for module procedures --- flang/include/flang/Parser/parse-tree.h | 3 ++- flang/lib/Parser/program-parsers.cpp | 3 ++- flang/lib/Semantics/program-tree.cpp | 3 +++ flang/test/Parser/compiler-directives.f90 | 8 ++++++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h index 393e0e24ec5cb..880f1e249d34a 100644 --- a/flang/include/flang/Parser/parse-tree.h +++ b/flang/include/flang/Parser/parse-tree.h @@ -2894,7 +2894,8 @@ struct ModuleSubprogram { UNION_CLASS_BOILERPLATE(ModuleSubprogram); std::variant, common::Indirection, - common::Indirection> + common::Indirection, + common::Indirection> u; }; diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp index e24559bf14f7c..ff5e58ebc721c 100644 --- a/flang/lib/Parser/program-parsers.cpp +++ b/flang/lib/Parser/program-parsers.cpp @@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US, // separate-module-subprogram TYPE_PARSER(construct(indirect(functionSubprogram)) || construct(indirect(subroutineSubprogram)) || - construct(indirect(Parser{}))) + construct(indirect(Parser{})) || + construct(indirect(compilerDirective))) // R1410 module-nature -> INTRINSIC | NON_INTRINSIC constexpr auto moduleNature{ diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp index bf773f3810c84..fcb6392620b67 100644 --- a/flang/lib/Semantics/program-tree.cpp +++ b/flang/lib/Semantics/program-tree.cpp @@ -111,6 +111,9 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) { if (subps) { for (const auto &subp : std::get>(subps->t)) { + if (std::holds_alternative< + common::Indirection>(subp.u)) + continue; common::visit( [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); }, subp.u); diff --git a/flang/test/Parser/compiler-directives.f90 b/flang/test/Parser/compiler-directives.f90 index 88cfd0944faf0..526f379326909 100644 --- a/flang/test/Parser/compiler-directives.f90 +++ b/flang/test/Parser/compiler-directives.f90 @@ -22,4 +22,12 @@ module m !dir$ optimize : 1 !dir$ loop count (10000) !dir$ loop count (1, 500, 5000, 10000) +contains + !dir$ noinline + subroutine sb1() + end subroutine + + !dir$ noinline + subroutine sb2() + end subroutine end From 8ce3b251d582a96fd2a668a7c29980faf5101ca0 Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Wed, 13 Dec 2023 11:39:39 +0000 Subject: [PATCH 2/7] WIP: [flang][Lower] support compiler directives inside modules in PFTBuilder Adding a list of all units inside a module in order makes it easier to match directives with functions that immediately follow them. The change to nested units inside of functions was incidental to make this compile, but I expect it should be similarly useful for matching attributes with the unit that follows them inside of a function body. --- flang/include/flang/Lower/PFTBuilder.h | 10 +++- flang/lib/Lower/Bridge.cpp | 33 +++++++++----- flang/lib/Lower/PFTBuilder.cpp | 63 +++++++++++++++++--------- 3 files changed, 72 insertions(+), 34 deletions(-) diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h index 9c6696ff79dae..16076c11483f6 100644 --- a/flang/include/flang/Lower/PFTBuilder.h +++ b/flang/include/flang/Lower/PFTBuilder.h @@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &); void dump(VariableList &, std::string s = {}); // `s` is an optional dump label +/// Things that can be nested inside of a module or function +/// TODO: add the rest +struct FunctionLikeUnit; +struct CompilerDirectiveUnit; +using NestedUnit = std::variant; + /// Function-like units may contain evaluations (executable statements) and /// nested function-like units (internal procedures and function statements). struct FunctionLikeUnit : public ProgramUnit { @@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit { EvaluationList evaluationList; LabelEvalMap labelEvaluationMap; SymbolLabelMap assignSymbolLabelMap; - std::list nestedFunctions; + std::list nestedUnits; /// pairs for each entry point. The pair at index 0 /// is the primary entry point; remaining pairs are alternate entry points. /// The primary entry point symbol is Null for an anonymous program. @@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit { ModuleStatement beginStmt; ModuleStatement endStmt; - std::list nestedFunctions; + std::list nestedUnits; EvaluationList evaluationList; }; diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 7e64adc3c144c..a08506bf6ebfa 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -303,9 +303,12 @@ class FirConverter : public Fortran::lower::AbstractConverter { }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerModuleDeclScope(m); - for (Fortran::lower::pft::FunctionLikeUnit &f : - m.nestedFunctions) - declareFunction(f); + for (Fortran::lower::pft::NestedUnit &unit : + m.nestedUnits) { + if (auto *f = std::get_if< + Fortran::lower::pft::FunctionLikeUnit>(&unit)) + declareFunction(*f); + } }, [&](Fortran::lower::pft::BlockDataUnit &b) { if (!globalOmpRequiresSymbol) @@ -387,13 +390,17 @@ class FirConverter : public Fortran::lower::AbstractConverter { // Compute the set of host associated entities from the nested functions. llvm::SetVector escapeHost; - for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions) - collectHostAssociatedVariables(f, escapeHost); + for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) { + if (auto *f = std::get_if(&nested)) + collectHostAssociatedVariables(*f, escapeHost); + } funit.setHostAssociatedSymbols(escapeHost); // Declare internal procedures - for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions) - declareFunction(f); + for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) { + if (auto *f = std::get_if(&nested)) + declareFunction(*f); + } } /// Get the scope that is defining or using \p sym. The returned scope is not @@ -4667,8 +4674,10 @@ class FirConverter : public Fortran::lower::AbstractConverter { endNewFunction(funit); } funit.setActiveEntry(0); - for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions) - lowerFunc(f); // internal procedure + for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) { + if (auto *f = std::get_if(&nested)) + lowerFunc(*f); // internal procedure + } } /// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC @@ -4692,8 +4701,10 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// Lower functions contained in a module. void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) { - for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions) - lowerFunc(f); + for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) { + if (auto *f = std::get_if(&unit)) + lowerFunc(*f); + } } void setCurrentPosition(const Fortran::parser::CharBlock &position) { diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp index 32ed539c775b8..0a08e1cf7ff47 100644 --- a/flang/lib/Lower/PFTBuilder.cpp +++ b/flang/lib/Lower/PFTBuilder.cpp @@ -259,6 +259,12 @@ class PFTBuilder { lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back())); return false; } + if (auto *mod = pftParentStack.back().getIf()) { + assert(nestedUnitList && "Modules have a nested units list"); + lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()}; + addNestedUnit(std::move(unit)); + return false; + } return enterConstructOrDirective(directive); } @@ -279,7 +285,7 @@ class PFTBuilder { bool enterModule(const A &mod) { Fortran::lower::pft::ModuleLikeUnit &unit = addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()}); - functionList = &unit.nestedFunctions; + nestedUnitList = &unit.nestedUnits; pushEvaluationList(&unit.evaluationList); pftParentStack.emplace_back(unit); LLVM_DEBUG(dumpScope(&unit.getScope())); @@ -349,7 +355,7 @@ class PFTBuilder { semanticsContext}); labelEvaluationMap = &unit.labelEvaluationMap; assignSymbolLabelMap = &unit.assignSymbolLabelMap; - functionList = &unit.nestedFunctions; + nestedUnitList = &unit.nestedUnits; pushEvaluationList(&unit.evaluationList); pftParentStack.emplace_back(unit); LLVM_DEBUG(dumpScope(&unit.getScope())); @@ -414,14 +420,14 @@ class PFTBuilder { if (!pftParentStack.empty()) { pftParentStack.back().visit(common::visitors{ [&](lower::pft::FunctionLikeUnit &p) { - functionList = &p.nestedFunctions; + nestedUnitList = &p.nestedUnits; labelEvaluationMap = &p.labelEvaluationMap; assignSymbolLabelMap = &p.assignSymbolLabelMap; }, [&](lower::pft::ModuleLikeUnit &p) { - functionList = &p.nestedFunctions; + nestedUnitList = &p.nestedUnits; }, - [&](auto &) { functionList = nullptr; }, + [&](auto &) { nestedUnitList = nullptr; }, }); } } @@ -432,11 +438,16 @@ class PFTBuilder { return std::get(pgm->getUnits().back()); } + template + void addNestedUnit(A &&source) { + nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)}); + } + template A &addFunction(A &&func) { - if (functionList) { - functionList->emplace_back(std::move(func)); - return functionList->back(); + if (nestedUnitList) { + addNestedUnit(func); + return std::get(nestedUnitList->back()); } return addUnit(std::move(func)); } @@ -459,7 +470,7 @@ class PFTBuilder { /// Append an Evaluation to the end of the current list. lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) { - assert(functionList && "not in a function"); + assert(nestedUnitList && "not in a function"); assert(!evaluationListStack.empty() && "empty evaluation list stack"); if (!constructAndDirectiveStack.empty()) eval.parentConstruct = constructAndDirectiveStack.back(); @@ -499,7 +510,7 @@ class PFTBuilder { /// push a new list on the stack of Evaluation lists void pushEvaluationList(lower::pft::EvaluationList *evaluationList) { - assert(functionList && "not in a function"); + assert(nestedUnitList && "not in a function"); assert(evaluationList && evaluationList->empty() && "evaluation list isn't correct"); evaluationListStack.emplace_back(evaluationList); @@ -507,7 +518,7 @@ class PFTBuilder { /// pop the current list and return to the last Evaluation list void popEvaluationList() { - assert(functionList && "not in a function"); + assert(nestedUnitList && "not in a function"); evaluationListStack.pop_back(); } @@ -1088,9 +1099,9 @@ class PFTBuilder { std::vector pftParentStack; const semantics::SemanticsContext &semanticsContext; - /// functionList points to the internal or module procedure function list - /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null. - std::list *functionList{}; + /// nestedUnitList points to the internal or module procedure unit list + /// of nested units (e.g. functions). It may be null. + std::list *nestedUnitList{}; std::vector constructAndDirectiveStack{}; std::vector doConstructStack{}; /// evaluationListStack is the current nested construct evaluationList state. @@ -1264,11 +1275,17 @@ class PFTDumper { outputStream << ": " << header; outputStream << '\n'; dumpEvaluationList(outputStream, functionLikeUnit.evaluationList); - if (!functionLikeUnit.nestedFunctions.empty()) { + if (!functionLikeUnit.nestedUnits.empty()) { outputStream << "\nContains\n"; - for (const lower::pft::FunctionLikeUnit &func : - functionLikeUnit.nestedFunctions) - dumpFunctionLikeUnit(outputStream, func); + for (const lower::pft::NestedUnit &nested : + functionLikeUnit.nestedUnits) { + if (const auto *func = + std::get_if(&nested)) + dumpFunctionLikeUnit(outputStream, *func); + if (const auto *directive = + std::get_if(&nested)) + dumpCompilerDirectiveUnit(outputStream, *directive); + } outputStream << "End Contains\n"; } outputStream << "End " << unitKind << ' ' << name << "\n\n"; @@ -1298,9 +1315,13 @@ class PFTDumper { outputStream << unitKind << ' ' << name << ": " << header << '\n'; dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList); outputStream << "Contains\n"; - for (const lower::pft::FunctionLikeUnit &func : - moduleLikeUnit.nestedFunctions) - dumpFunctionLikeUnit(outputStream, func); + for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) { + if (const auto *func = std::get_if(&nested)) + dumpFunctionLikeUnit(outputStream, *func); + if (const auto *directive = + std::get_if(&nested)) + dumpCompilerDirectiveUnit(outputStream, *directive); + } outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n"; } From 0453b9c7b742fded974e7f3cf38c1d4c61d50904 Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Mon, 11 Dec 2023 16:29:51 +0000 Subject: [PATCH 3/7] WIP: [flang][Lower] add attributes for arm streaming sve directives --- flang/lib/Lower/Bridge.cpp | 76 ++++++++++++++++++++++-- flang/lib/Semantics/resolve-names.cpp | 26 +++++++- flang/test/Lower/arm-ssve-directives.f90 | 46 ++++++++++++++ 3 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 flang/test/Lower/arm-ssve-directives.f90 diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index a08506bf6ebfa..5e0b6a801a98e 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -53,6 +53,7 @@ #include "flang/Semantics/runtime-type-info.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" @@ -325,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter { [&]() { createIntrinsicModuleDefinitions(pft); }); // Primary translation pass. - for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) { + std::list &units = pft.getUnits(); + for (auto it = units.begin(); it != units.end(); it = std::next(it)) { std::visit( Fortran::common::visitors{ [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); }, [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); }, [&](Fortran::lower::pft::BlockDataUnit &b) {}, - [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {}, + [&](Fortran::lower::pft::CompilerDirectiveUnit &d) { + processSubprogramDirective(it, units.end(), d); + }, [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) { builder = new fir::FirOpBuilder(bridge.getModule(), bridge.getKindMap()); @@ -341,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { builder = nullptr; }, }, - u); + *it); } // Once all the code has been translated, create global runtime type info @@ -4701,9 +4705,15 @@ class FirConverter : public Fortran::lower::AbstractConverter { /// Lower functions contained in a module. void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) { - for (Fortran::lower::pft::NestedUnit &unit : mod.nestedUnits) { - if (auto *f = std::get_if(&unit)) - lowerFunc(*f); + for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end(); + it = std::next(it)) { + std::visit( + Fortran::common::visitors{ + [&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); }, + [&](Fortran::lower::pft::CompilerDirectiveUnit &d) { + processSubprogramDirective(it, mod.nestedUnits.end(), d); + }}, + *it); } } @@ -5012,6 +5022,60 @@ class FirConverter : public Fortran::lower::AbstractConverter { globalOmpRequiresSymbol); } + /// Process compiler directives that apply to subprograms + template + void + processSubprogramDirective(ITERATOR it, ITERATOR endIt, + Fortran::lower::pft::CompilerDirectiveUnit &d) { + auto *parserDirective = d.getIf(); + if (!parserDirective) + return; + auto *nvList = + std::get_if>( + &parserDirective->u); + if (!nvList) + return; + + // get the function the directive applies to (hopefully the next unit) + mlir::func::FuncOp mlirFunc; + it = std::next(it); + if (it != endIt) { + auto *pftFunction = + std::get_if(&*it); + if (pftFunction) { + Fortran::lower::CalleeInterface callee{*pftFunction, *this}; + mlirFunc = callee.getFuncOp(); + } + } + + for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) { + std::string name = std::get(nv.t).ToString(); + + // arm streaming sve directives + auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled; + if (name == "arm_streaming") + streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming; + else if (name == "arm_locally_streaming") + streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally; + else if (name == "arm_streaming_compatible") + streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible; + if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) { + if (!mlirFunc) { + // TODO: share diagnostic code with warnings elsewhere + // TODO: source location is printed as loc<"file.f90":line:col> + mlir::Location loc = genLocation(parserDirective->source); + llvm::errs() << loc << ": warning: ignoring directive '" << name + << "' because it has no associated subprogram\n"; + continue; + } + llvm::StringRef attrName = + mlir::arm_sme::stringifyArmStreamingMode(streamingMode); + mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext()); + mlirFunc->setAttr(attrName, unitAttr); + } + } + } + //===--------------------------------------------------------------------===// Fortran::lower::LoweringBridge &bridge; diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index e1cd34ddf65b6..40fb641e085a7 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -8389,7 +8389,31 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) { } } } else { - Say(x.source, "Compiler directive was ignored"_warn_en_US); + bool handled = false; + if (const auto *nvList{ + std::get_if>( + &x.u)}) { + for (const parser::CompilerDirective::NameValue &nv : *nvList) { + std::string name = std::get(nv.t).ToString(); + const std::initializer_list handledAttrs{ + "arm_streaming", + "arm_locally_streaming", + "arm_streaming_compatible", + }; + if (std::find(handledAttrs.begin(), handledAttrs.end(), name) == + handledAttrs.end()) { + // exit early so that subsequent recognised attributes can't change + // the result + handled = false; + break; + } + // this attribute was handled + handled = true; + } + } + if (!handled) { + Say(x.source, "Compiler directive was ignored"_warn_en_US); + } } } diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90 new file mode 100644 index 0000000000000..86fbe89920b27 --- /dev/null +++ b/flang/test/Lower/arm-ssve-directives.f90 @@ -0,0 +1,46 @@ +! RUN: bbc -emit-hlfir %s -o - 2>&1 | FileCheck %s + +! check we don't warn about these attributes +! CHECK-NOT: warning: Compiler directive was ignored + +! check we create the right fuction attributes + +!dir$ arm_streaming +subroutine sub +end subroutine sub +! CHECK-LABEL: func.func @_QPsub() +! CHECK-SAME: attributes {arm_streaming} + +!dir$ arm_locally_streaming +subroutine sub2 +end subroutine sub2 +! CHECK-LABEL: func.func @_QPsub2() +! CHECK-SAME: attributes {arm_locally_streaming} + +!dir$ arm_streaming_compatible +subroutine sub3 +end subroutine sub3 +! CHECK-LABEL: func.func @_QPsub3() +! CHECK-SAME: attributes {arm_streaming_compatible} + +module m +contains + +!dir$ arm_streaming +subroutine msub +end subroutine msub +! CHECK-LABEL: func.func @_QMmPmsub() +! CHECK-SAME: attributes {arm_streaming} + +!dir$ arm_locally_streaming +subroutine msub2 +end subroutine msub2 +! CHECK-LABEL: func.func @_QMmPmsub2() +! CHECK-SAME: attributes {arm_locally_streaming} + +!dir$ arm_streaming_compatible +subroutine msub3 +end subroutine msub3 +! CHECK-LABEL: func.func @_QMmPmsub3() +! CHECK-SAME: attributes {arm_streaming_compatible} +end module From 69384d37424e3008bf46ce283bda8561374af44b Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Wed, 13 Dec 2023 17:23:41 +0000 Subject: [PATCH 4/7] fixup! [Flang] WIP: Allow compiler directives for module procedures Use the call to visit instead of a separate std::holds_alternative --- flang/lib/Semantics/program-tree.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp index fcb6392620b67..50487ea58dff7 100644 --- a/flang/lib/Semantics/program-tree.cpp +++ b/flang/lib/Semantics/program-tree.cpp @@ -111,11 +111,12 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) { if (subps) { for (const auto &subp : std::get>(subps->t)) { - if (std::holds_alternative< - common::Indirection>(subp.u)) - continue; common::visit( - [&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); }, + common::visitors{ + [&](const common::Indirection &) {}, + [&](const auto &y) { + node.AddChild(ProgramTree::Build(y.value())); + }}, subp.u); } } From 7e9b552a3bae17d1c5eb19fc20a57e68c0c1480e Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Wed, 13 Dec 2023 17:26:10 +0000 Subject: [PATCH 5/7] fixup! WIP: [flang][Lower] add attributes for arm streaming sve directives Use braced initialization --- flang/lib/Semantics/resolve-names.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index 40fb641e085a7..e619e3e0961ec 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -8389,12 +8389,12 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) { } } } else { - bool handled = false; + bool handled{false}; if (const auto *nvList{ std::get_if>( &x.u)}) { for (const parser::CompilerDirective::NameValue &nv : *nvList) { - std::string name = std::get(nv.t).ToString(); + std::string name{std::get(nv.t).ToString()}; const std::initializer_list handledAttrs{ "arm_streaming", "arm_locally_streaming", From 7eb4a3d0a1cb2a33d65375deac4f3177cb85d7ee Mon Sep 17 00:00:00 2001 From: Mats Petersson Date: Thu, 28 Dec 2023 14:07:49 +0000 Subject: [PATCH 6/7] [flang][Lower] Add the ZA mode directives to support amr streaming sve This adds the arm_new_za, arm_shared_za and arm_preserves_za directives. Also adds two new enum values in the MLIR defintions for ArmZaMode. --- flang/lib/Lower/Bridge.cpp | 20 ++++++++++ flang/lib/Semantics/resolve-names.cpp | 3 ++ flang/test/Lower/arm-ssve-directives.f90 | 37 +++++++++++++++++++ .../mlir/Dialect/ArmSME/Transforms/Passes.td | 6 ++- 4 files changed, 64 insertions(+), 2 deletions(-) diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 5e0b6a801a98e..043c510419181 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -5073,6 +5073,26 @@ class FirConverter : public Fortran::lower::AbstractConverter { mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext()); mlirFunc->setAttr(attrName, unitAttr); } + auto zaMode = mlir::arm_sme::ArmZaMode::Disabled; + if (name == "arm_new_za") + zaMode = mlir::arm_sme::ArmZaMode::NewZA; + else if (name == "arm_shared_za") + zaMode = mlir::arm_sme::ArmZaMode::SharedZA; + else if (name == "arm_preserves_za") + zaMode = mlir::arm_sme::ArmZaMode::PreservesZA; + if (zaMode != mlir::arm_sme::ArmZaMode::Disabled) { + if (!mlirFunc) { + // TODO: share diagnostic code with warnings elsewhere + // TODO: source location is printed as loc<"file.f90":line:col> + mlir::Location loc = genLocation(parserDirective->source); + llvm::errs() << loc << ": warning: ignoring directive '" << name + << "' because it has no associated subprogram\n"; + continue; + } + llvm::StringRef attrName = mlir::arm_sme::stringifyArmZaMode(zaMode); + mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext()); + mlirFunc->setAttr(attrName, unitAttr); + } } } diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp index e619e3e0961ec..ef8accd4636c3 100644 --- a/flang/lib/Semantics/resolve-names.cpp +++ b/flang/lib/Semantics/resolve-names.cpp @@ -8399,6 +8399,9 @@ void ResolveNamesVisitor::Post(const parser::CompilerDirective &x) { "arm_streaming", "arm_locally_streaming", "arm_streaming_compatible", + "arm_shared_za", + "arm_new_za", + "arm_preserves_za", }; if (std::find(handledAttrs.begin(), handledAttrs.end(), name) == handledAttrs.end()) { diff --git a/flang/test/Lower/arm-ssve-directives.f90 b/flang/test/Lower/arm-ssve-directives.f90 index 86fbe89920b27..dd4644f336b6b 100644 --- a/flang/test/Lower/arm-ssve-directives.f90 +++ b/flang/test/Lower/arm-ssve-directives.f90 @@ -23,6 +23,24 @@ end subroutine sub3 ! CHECK-LABEL: func.func @_QPsub3() ! CHECK-SAME: attributes {arm_streaming_compatible} +!dir$ arm_new_za +subroutine sub4 +end subroutine sub4 +! CHECK-LABEL: func.func @_QPsub4() +! CHECK-SAME: attributes {arm_new_za} + +!dir$ arm_shared_za +subroutine sub5 +end subroutine sub5 +! CHECK-LABEL: func.func @_QPsub5() +! CHECK-SAME: attributes {arm_shared_za} + +!dir$ arm_preserves_za +subroutine sub6 +end subroutine sub6 +! CHECK-LABEL: func.func @_QPsub6() +! CHECK-SAME: attributes {arm_preserves_za} + module m contains @@ -43,4 +61,23 @@ subroutine msub3 end subroutine msub3 ! CHECK-LABEL: func.func @_QMmPmsub3() ! CHECK-SAME: attributes {arm_streaming_compatible} + +!dir$ arm_new_za +subroutine msub4 +end subroutine msub4 +! CHECK-LABEL: func.func @_QMmPmsub4() +! CHECK-SAME: attributes {arm_new_za} + +!dir$ arm_shared_za +subroutine msub5 +end subroutine msub5 +! CHECK-LABEL: func.func @_QMmPmsub5() +! CHECK-SAME: attributes {arm_shared_za} + +!dir$ arm_preserves_za +subroutine msub6 +end subroutine msub6 +! CHECK-LABEL: func.func @_QMmPmsub6() +! CHECK-SAME: attributes {arm_preserves_za} + end module diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td index 4266ac5b0c8cf..57f9ac007bae9 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td @@ -28,13 +28,15 @@ def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode let genSpecializedAttr = 0; } -// TODO: Add other ZA modes. -// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode", [ I32EnumAttrCase<"Disabled", 0, "disabled">, // A function's ZA state is created on entry and destroyed on exit. I32EnumAttrCase<"NewZA", 1, "arm_new_za">, + // A function that preserves ZA state. + I32EnumAttrCase<"PreservesZA", 2, "arm_preserves_za">, + // A function that uses ZA state as input and/or output + I32EnumAttrCase<"SharedZA", 3, "arm_shared_za">, ]>{ let cppNamespace = "mlir::arm_sme"; let genSpecializedAttr = 0; From 6ab4020dbae116f29f95a308251e6cc56a97de2e Mon Sep 17 00:00:00 2001 From: Mats Petersson Date: Fri, 29 Dec 2023 14:16:27 +0000 Subject: [PATCH 7/7] Add docs for directives --- flang/docs/Directives.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/flang/docs/Directives.md b/flang/docs/Directives.md index c8a2c087dfad1..9bd52f11a4fb2 100644 --- a/flang/docs/Directives.md +++ b/flang/docs/Directives.md @@ -29,3 +29,28 @@ A list of non-standard directives supported by Flang end end interface ``` + +## ARM Streaming SVE directives + +These directives are added to support ARM specific instructions. All of +these attributes apply to a specific subroutine or function. These directives +are identical to the attributes provided in C and C++ for the same purpose. +See https://arm-software.github.io/acle/main/acle.html#controlling-the-use-of-streaming-mode for more in depth details. (For the following, function is used +to mean both subroutine and function). + +### Directives relating to ARM Streaming mode + +* `!dir$ arm_streaming` - The function is intended to be used in streaming + mode. +* `!dir$ arm_streaming_compatible` - The function can work both in streaming + mode and non-streaming mode. +* `!dir$ arm_streaming` - The function will enter streaming mode, and return to + non-streaming mode when reaturning. + +### Directives relating to ZA + +* `!dir$ arm_shared_za` - A function that uses ZA for input or output. +* `!dir$ arm_new_za` - A function that has ZA state created and destroyed within + the function. +* `!dir$ arm_preserves_za` - Optimisation hint for the compiler that the + function either doesn't alter, or saves and restores the ZA state.