From aee625a7a2b0b1d328823dd410522ab04cb8ee65 Mon Sep 17 00:00:00 2001 From: Peter Klausler Date: Thu, 3 Jul 2025 15:17:33 -0700 Subject: [PATCH] [flang] Add general symbol dependence collection utility Replace HarvestSymbolsNeededFromOtherModules() in mod-file.cpp with a general utility function in Semantics. This new code will find other uses in further rework of hermetic module file generation as the means by which the necessary subsets of symbols in dependency modules are collected. --- .../flang/Semantics/symbol-dependence.h | 36 ++ flang/lib/Semantics/CMakeLists.txt | 1 + flang/lib/Semantics/mod-file.cpp | 76 +--- flang/lib/Semantics/symbol-dependence.cpp | 356 ++++++++++++++++++ flang/test/Semantics/modfile44.f90 | 2 + flang/test/Semantics/modfile69.f90 | 2 + 6 files changed, 406 insertions(+), 67 deletions(-) create mode 100644 flang/include/flang/Semantics/symbol-dependence.h create mode 100644 flang/lib/Semantics/symbol-dependence.cpp diff --git a/flang/include/flang/Semantics/symbol-dependence.h b/flang/include/flang/Semantics/symbol-dependence.h new file mode 100644 index 0000000000000..9bcff564a4c04 --- /dev/null +++ b/flang/include/flang/Semantics/symbol-dependence.h @@ -0,0 +1,36 @@ +//===-- include/flang/Semantics/symbol-dependence.h -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_ +#define FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_ + +#include "flang/Semantics/symbol.h" + +namespace Fortran::semantics { + +// For a set or scope of symbols, computes the transitive closure of their +// dependences due to their types, bounds, specific procedures, interfaces, +// initialization, storage association, &c. 
Includes the original symbol +// or members of the original set. Does not include dependences from +// subprogram definitions, only their interfaces. +enum DependenceCollectionFlags { + NoDependenceCollectionFlags = 0, + IncludeOriginalSymbols = 1 << 0, + FollowUseAssociations = 1 << 1, + IncludeSpecificsOfGenerics = 1 << 2, + IncludeUsesOfGenerics = 1 << 3, + NotJustForOneModule = 1 << 4, +}; + +SymbolVector CollectAllDependences(const SymbolVector &, + int = NoDependenceCollectionFlags, const Scope * = nullptr); +SymbolVector CollectAllDependences( + const Scope &, int = NoDependenceCollectionFlags); + +} // namespace Fortran::semantics +#endif // FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_ diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt index 109bc2dbb8569..c711103e793d8 100644 --- a/flang/lib/Semantics/CMakeLists.txt +++ b/flang/lib/Semantics/CMakeLists.txt @@ -48,6 +48,7 @@ add_flang_library(FortranSemantics runtime-type-info.cpp scope.cpp semantics.cpp + symbol-dependence.cpp symbol.cpp tools.cpp type.cpp diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 82c8536902eb2..ab851587bb49a 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -15,6 +15,7 @@ #include "flang/Parser/unparse.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" +#include "flang/Semantics/symbol-dependence.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "llvm/Support/FileSystem.h" @@ -223,72 +224,13 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) { // Collect symbols from constant and specification expressions that are being // referenced directly from other modules; they may require new USE // associations. 
-static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &, const Scope &); -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) { - auto HarvestBound{[&](const Bound &bound) { - if (const auto &expr{bound.GetExplicit()}) { - for (SymbolRef ref : evaluate::CollectSymbols(*expr)) { - set.emplace(*ref); - } - } - }}; - auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) { - HarvestBound(shapeSpec.lbound()); - HarvestBound(shapeSpec.ubound()); - }}; - auto HarvestArraySpec{[&](const ArraySpec &arraySpec) { - for (const auto &shapeSpec : arraySpec) { - HarvestShapeSpec(shapeSpec); - } - }}; - - if (symbol.has()) { - if (symbol.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *symbol.scope()); - } - } else if (const auto &generic{symbol.detailsIf()}; - generic && generic->derivedType()) { - const Symbol &dtSym{*generic->derivedType()}; - if (dtSym.has()) { - if (dtSym.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope()); - } - } else { - CHECK(dtSym.has() || dtSym.has()); - } - } else if (const auto *object{symbol.detailsIf()}) { - HarvestArraySpec(object->shape()); - HarvestArraySpec(object->coshape()); - if (IsNamedConstant(symbol) || scope.IsDerivedType()) { - if (object->init()) { - for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) { - set.emplace(*ref); - } - } - } - } else if (const auto *proc{symbol.detailsIf()}) { - if (proc->init() && *proc->init() && scope.IsDerivedType()) { - set.emplace(**proc->init()); - } - } else if (const auto *subp{symbol.detailsIf()}) { - for (const Symbol *dummy : subp->dummyArgs()) { - if (dummy) { - HarvestSymbolsNeededFromOtherModules(set, *dummy, scope); - } - } - if (subp->isFunction()) { - HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope); - } - } -} - -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &set, const Scope &scope) { - for (const auto &[_, symbol] : 
scope) { - HarvestSymbolsNeededFromOtherModules(set, *symbol, scope); +static SourceOrderedSymbolSet HarvestSymbolsNeededFromOtherModules( + const Scope &scope) { + SourceOrderedSymbolSet set; + for (const Symbol &symbol : CollectAllDependences(scope)) { + set.insert(symbol); } + return set; } void ModFileWriter::PrepareRenamings(const Scope &scope) { @@ -304,8 +246,8 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) { } } // Collect symbols needed from other modules - SourceOrderedSymbolSet symbolsNeeded; - HarvestSymbolsNeededFromOtherModules(symbolsNeeded, scope); + SourceOrderedSymbolSet symbolsNeeded{ + HarvestSymbolsNeededFromOtherModules(scope)}; // Establish any necessary renamings of symbols in other modules // to their names in this scope, creating those new names when needed. auto &renamings{context_.moduleFileOutputRenamings()}; diff --git a/flang/lib/Semantics/symbol-dependence.cpp b/flang/lib/Semantics/symbol-dependence.cpp new file mode 100644 index 0000000000000..2591f609f3d00 --- /dev/null +++ b/flang/lib/Semantics/symbol-dependence.cpp @@ -0,0 +1,356 @@ +//===-- lib/Semantics/symbol-dependence.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Semantics/symbol-dependence.h" +#include "flang/Common/idioms.h" +#include "flang/Common/restorer.h" +#include "flang/Common/visit.h" +#include + +static constexpr bool EnableDebugging{false}; + +namespace Fortran::semantics { + +// Helper class that collects all of the symbol dependences for a +// given symbol. 
+class Collector { +public: + explicit Collector(int flags) : flags_{flags} {} + + void CollectSymbolDependences(const Symbol &); + UnorderedSymbolSet MustFollowDependences() { return std::move(dependences_); } + SymbolVector AllDependences() { return std::move(mentions_); } + +private: + // This symbol is depended upon and its declaration must precede + // the symbol of interest. + void MustFollow(const Symbol &x) { + if (!possibleImports_ || !DoesScopeContain(possibleImports_, x)) { + dependences_.insert(x); + } + } + // This symbol is depended upon, but is not necessarily a dependence + // that must precede the symbol of interest in the output of the + // topological sort. + void Need(const Symbol &x) { + if (mentioned_.insert(x).second) { + mentions_.emplace_back(x); + } + } + void Need(const Symbol *x) { + if (x) { + Need(*x); + } + } + + // These overloads of Collect() are mutually recursive, so they're + // packaged as member functions of a class. + void Collect(const Symbol &x) { + Need(x); + const auto *subp{x.detailsIf()}; + if ((subp && subp->isInterface()) || IsDummy(x) || + x.has() || x.has()) { + // can be forward-referenced + } else { + MustFollow(x); + } + } + void Collect(SymbolRef x) { Collect(*x); } + template void Collect(const std::optional &x) { + if (x) { + Collect(*x); + } + } + template void Collect(const A *x) { + if (x) { + Collect(*x); + } + } + void Collect(const UnorderedSymbolSet &x) { + for (const Symbol &symbol : x) { + Collect(symbol); + } + } + void Collect(const SourceOrderedSymbolSet &x) { + for (const Symbol &symbol : x) { + Collect(symbol); + } + } + void Collect(const SymbolVector &x) { + for (const Symbol &symbol : x) { + Collect(symbol); + } + } + void Collect(const Scope &x) { Collect(x.GetSymbols()); } + template void Collect(const evaluate::Expr &x) { + UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)}; + for (const Symbol &sym : exprSyms) { + if (!sym.owner().IsDerivedType()) { + Collect(sym); + } + } + } + void 
Collect(const DeclTypeSpec &type) { + if (type.category() == DeclTypeSpec::Category::Character) { + Collect(type.characterTypeSpec().length()); + } else { + Collect(type.AsDerived()); + } + } + void Collect(const DerivedTypeSpec &type) { + const Symbol &typeSym{type.originalTypeSymbol()}; + if (!derivedTypeReferenceCanBeForward_ || !type.parameters().empty()) { + MustFollow(typeSym); + } + Need(typeSym); + for (const auto &[_, value] : type.parameters()) { + Collect(value); + } + } + void Collect(const ParamValue &x) { Collect(x.GetExplicit()); } + void Collect(const Bound &x) { Collect(x.GetExplicit()); } + void Collect(const ShapeSpec &x) { + Collect(x.lbound()); + Collect(x.ubound()); + } + void Collect(const ArraySpec &x) { + for (const ShapeSpec &shapeSpec : x) { + Collect(shapeSpec); + } + } + + UnorderedSymbolSet mentioned_, dependences_; + SymbolVector mentions_; + int flags_{NoDependenceCollectionFlags}; + bool derivedTypeReferenceCanBeForward_{false}; + const Scope *possibleImports_{nullptr}; +}; + +void Collector::CollectSymbolDependences(const Symbol &symbol) { + if (symbol.has() || symbol.has()) { + // type will be picked up later for the function result, if any + } else if (symbol.has() || symbol.has() || + symbol.has()) { + } else if (IsAllocatableOrPointer(symbol) && symbol.owner().IsDerivedType()) { + bool saveCanBeForward{derivedTypeReferenceCanBeForward_}; + derivedTypeReferenceCanBeForward_ = true; + Collect(symbol.GetType()); + derivedTypeReferenceCanBeForward_ = saveCanBeForward; + } else { + Collect(symbol.GetType()); + } + common::visit( + common::visitors{ + [this, &symbol](const ObjectEntityDetails &x) { + Collect(x.shape()); + Collect(x.coshape()); + if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) { + Collect(x.init()); + } + Need(x.commonBlock()); + if (const auto *set{FindEquivalenceSet(symbol)}) { + for (const EquivalenceObject &equivObject : *set) { + Need(equivObject.symbol); + } + } + }, + [this, &symbol](const 
ProcEntityDetails &x) { + Collect(x.rawProcInterface()); + if (symbol.owner().IsDerivedType()) { + Collect(x.init()); + } + }, + [this](const ProcBindingDetails &x) { Need(x.symbol()); }, + [this, &symbol](const SubprogramDetails &x) { + // Note dummy arguments & result symbol without dependence, unless + // the subprogram is an interface block that might need to IMPORT + // a type. + bool needImports{x.isInterface()}; + auto restorer{common::ScopedSet( + possibleImports_, needImports ? symbol.scope() : nullptr)}; + for (const Symbol *dummy : x.dummyArgs()) { + if (dummy) { + Need(*dummy); + if (needImports) { + CollectSymbolDependences(*dummy); + } + } + } + if (x.isFunction()) { + Need(x.result()); + if (needImports) { + CollectSymbolDependences(x.result()); + } + } + }, + [this, &symbol](const DerivedTypeDetails &x) { + Collect(symbol.scope()); + for (const auto &[_, symbolRef] : x.finals()) { + Need(*symbolRef); + } + }, + [this](const GenericDetails &x) { + Collect(x.derivedType()); + Collect(x.specific()); + if (flags_ & IncludeUsesOfGenerics) { + for (const Symbol &use : x.uses()) { + Collect(use); + } + } + if (flags_ & IncludeSpecificsOfGenerics) { + for (const Symbol &specific : x.specificProcs()) { + Collect(specific); + } + } + }, + [this](const NamelistDetails &x) { + for (const Symbol &symbol : x.objects()) { + Collect(symbol); + } + }, + [this](const CommonBlockDetails &x) { + for (auto ref : x.objects()) { + Collect(*ref); + } + }, + [this](const UseDetails &x) { + if (flags_ & FollowUseAssociations) { + Need(x.symbol()); + } + }, + [this](const HostAssocDetails &x) { Need(x.symbol()); }, + [](const auto &) {}, + }, + symbol.details()); +} + +SymbolVector CollectAllDependences(const Scope &scope, int flags) { + SymbolVector basis{scope.GetSymbols()}; + return CollectAllDependences(basis, flags, &scope); +} + +// Returns a vector of symbols, topologically sorted by dependence +SymbolVector CollectAllDependences( + const SymbolVector &original, int 
flags, const Scope *forScope) { + std::queue work; + UnorderedSymbolSet enqueued; + for (const Symbol &symbol : original) { + if (!symbol.test(Symbol::Flag::CompilerCreated)) { + work.push(&symbol); + enqueued.insert(symbol); + } + } + // For each symbol, collect its dependences into "topology". + // The "visited" vector and "enqueued" set hold all of the + // symbols considered. + std::map topology; + std::vector visited; + visited.reserve(2 * original.size()); + std::optional forModuleName; + if (forScope && !(flags & NotJustForOneModule)) { + if (const Scope *forModule{FindModuleContaining(*forScope)}) { + forModuleName = forModule->GetName(); + } + } + while (!work.empty()) { + const Symbol &symbol{*work.front()}; + work.pop(); + visited.push_back(&symbol); + Collector collector{flags}; + bool doCollection{true}; + if (forModuleName) { + if (const Scope *symModule{FindModuleContaining(symbol.owner())}) { + if (auto symModName{symModule->GetName()}) { + doCollection = *forModuleName == *symModName; + } + } + } + if (doCollection) { + collector.CollectSymbolDependences(symbol); + } + auto dependences{collector.MustFollowDependences()}; + auto mentions{collector.AllDependences()}; + if constexpr (EnableDebugging) { + for (const Symbol &need : dependences) { + llvm::errs() << "symbol " << symbol << " must follow " << need << '\n'; + } + for (const Symbol &need : mentions) { + llvm::errs() << "symbol " << symbol << " needs " << need << '\n'; + } + } + CHECK(topology.find(&symbol) == topology.end()); + topology.emplace(&symbol, std::move(dependences)); + for (const Symbol &symbol : mentions) { + if (!symbol.test(Symbol::Flag::CompilerCreated)) { + if (enqueued.insert(symbol).second) { + work.push(&symbol); + } + } + } + } + CHECK(enqueued.size() == visited.size()); + // Topological sorting + // Subtle: This inverted topology map uses a SymbolVector, not a set + // of symbols, so that the order of symbols in the final output remains + // deterministic. 
+ std::map invertedTopology; + for (const Symbol *symbol : visited) { + invertedTopology[symbol] = SymbolVector{}; + } + std::map numWaitingFor; + for (const Symbol *symbol : visited) { + auto topoIter{topology.find(symbol)}; + CHECK(topoIter != topology.end()); + const auto &needs{topoIter->second}; + if (needs.empty()) { + work.push(symbol); + } else { + numWaitingFor[symbol] = needs.size(); + for (const Symbol &need : needs) { + invertedTopology[&need].push_back(*symbol); + } + } + } + CHECK(visited.size() == work.size() + numWaitingFor.size()); + SymbolVector resultVector; + while (!work.empty()) { + const Symbol &symbol{*work.front()}; + work.pop(); + resultVector.push_back(symbol); + auto enqueuedIter{enqueued.find(symbol)}; + CHECK(enqueuedIter != enqueued.end()); + enqueued.erase(enqueuedIter); + if (auto invertedIter{invertedTopology.find(&symbol)}; + invertedIter != invertedTopology.end()) { + for (const Symbol &neededBy : invertedIter->second) { + std::size_t stillAwaiting{numWaitingFor[&neededBy] - 1}; + if (stillAwaiting == 0) { + work.push(&neededBy); + } else { + numWaitingFor[&neededBy] = stillAwaiting; + } + } + } + } + if constexpr (EnableDebugging) { + llvm::errs() << "Topological sort failed in CollectAllDependences\n"; + for (const Symbol &remnant : enqueued) { + auto topoIter{topology.find(&remnant)}; + CHECK(topoIter != topology.end()); + llvm::errs() << " remnant symbol " << remnant << " needs:\n"; + for (const Symbol &n : topoIter->second) { + llvm::errs() << " " << n << '\n'; + } + } + } + CHECK(enqueued.empty()); + CHECK(resultVector.size() == visited.size()); + return resultVector; +} + +} // namespace Fortran::semantics diff --git a/flang/test/Semantics/modfile44.f90 b/flang/test/Semantics/modfile44.f90 index 23d93b18a5a1f..48f6250cf6ce8 100644 --- a/flang/test/Semantics/modfile44.f90 +++ b/flang/test/Semantics/modfile44.f90 @@ -45,6 +45,8 @@ function foo(j) result(res) !Expect: m2.mod !module m2 +!use m1,only:xyz +!private::xyz 
!contains !function foo(j) result(res) !use m1,only:xyz diff --git a/flang/test/Semantics/modfile69.f90 b/flang/test/Semantics/modfile69.f90 index 6586e0524f5ea..f969573efdd78 100644 --- a/flang/test/Semantics/modfile69.f90 +++ b/flang/test/Semantics/modfile69.f90 @@ -36,6 +36,8 @@ subroutine sub(x) !Expect: m3.mod !module m3 +!use m2,only:bar +!private::bar !contains !subroutine sub(x) !use m2,only:bar