diff --git a/flang/include/flang/Semantics/symbol-dependence.h b/flang/include/flang/Semantics/symbol-dependence.h new file mode 100644 index 0000000000000..9bcff564a4c04 --- /dev/null +++ b/flang/include/flang/Semantics/symbol-dependence.h @@ -0,0 +1,36 @@ +//===-- include/flang/Semantics/symbol-dependence.h -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_ +#define FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_ + +#include "flang/Semantics/symbol.h" + +namespace Fortran::semantics { + +// For a set or scope of symbols, computes the transitive closure of their +// dependences due to their types, bounds, specific procedures, interfaces, +// initialization, storage association, &c. Includes the original symbol +// or members of the original set. Does not include dependences from +// subprogram definitions, only their interfaces. +enum DependenceCollectionFlags { + NoDependenceCollectionFlags = 0, + IncludeOriginalSymbols = 1 << 0, + FollowUseAssociations = 1 << 1, + IncludeSpecificsOfGenerics = 1 << 2, + IncludeUsesOfGenerics = 1 << 3, + NotJustForOneModule = 1 << 4, +}; + +SymbolVector CollectAllDependences(const SymbolVector &, + int = NoDependenceCollectionFlags, const Scope * = nullptr); +SymbolVector CollectAllDependences( + const Scope &, int = NoDependenceCollectionFlags); + +} // namespace Fortran::semantics +#endif // FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_ diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt index 109bc2dbb8569..c711103e793d8 100644 --- a/flang/lib/Semantics/CMakeLists.txt +++ b/flang/lib/Semantics/CMakeLists.txt @@ -48,6 +48,7 @@ add_flang_library(FortranSemantics runtime-type-info.cpp scope.cpp semantics.cpp + symbol-dependence.cpp symbol.cpp tools.cpp type.cpp diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp index 82c8536902eb2..ab851587bb49a 100644 --- a/flang/lib/Semantics/mod-file.cpp +++ b/flang/lib/Semantics/mod-file.cpp @@ -15,6 +15,7 @@ #include "flang/Parser/unparse.h" #include "flang/Semantics/scope.h" #include "flang/Semantics/semantics.h" +#include "flang/Semantics/symbol-dependence.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/tools.h" #include "llvm/Support/FileSystem.h" @@ -223,72 +224,13 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) { // Collect symbols from constant and specification expressions that are being // referenced directly from other modules; they may require new USE // associations. -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &, const Scope &); -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) { - auto HarvestBound{[&](const Bound &bound) { - if (const auto &expr{bound.GetExplicit()}) { - for (SymbolRef ref : evaluate::CollectSymbols(*expr)) { - set.emplace(*ref); - } - } - }}; - auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) { - HarvestBound(shapeSpec.lbound()); - HarvestBound(shapeSpec.ubound()); - }}; - auto HarvestArraySpec{[&](const ArraySpec &arraySpec) { - for (const auto &shapeSpec : arraySpec) { - HarvestShapeSpec(shapeSpec); - } - }}; - - if (symbol.has()) { - if (symbol.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *symbol.scope()); - } - } else if (const auto &generic{symbol.detailsIf()}; - generic && generic->derivedType()) { - const Symbol &dtSym{*generic->derivedType()}; - if (dtSym.has()) { - if (dtSym.scope()) { - HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope()); - } - } else { - CHECK(dtSym.has() || dtSym.has()); - } - } else if (const auto *object{symbol.detailsIf()}) { - HarvestArraySpec(object->shape()); - HarvestArraySpec(object->coshape()); - if (IsNamedConstant(symbol) || scope.IsDerivedType()) { - if (object->init()) { - for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) { - set.emplace(*ref); - } - } - } - } else if (const auto *proc{symbol.detailsIf()}) { - if (proc->init() && *proc->init() && scope.IsDerivedType()) { - set.emplace(**proc->init()); - } - } else if (const auto *subp{symbol.detailsIf()}) { - for (const Symbol *dummy : subp->dummyArgs()) { - if (dummy) { - HarvestSymbolsNeededFromOtherModules(set, *dummy, scope); - } - } - if (subp->isFunction()) { - HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope); - } - } -} - -static void HarvestSymbolsNeededFromOtherModules( - SourceOrderedSymbolSet &set, const Scope &scope) { - for (const auto &[_, symbol] : scope) { - HarvestSymbolsNeededFromOtherModules(set, *symbol, scope); +static SourceOrderedSymbolSet HarvestSymbolsNeededFromOtherModules( + const Scope &scope) { + SourceOrderedSymbolSet set; + for (const Symbol &symbol : CollectAllDependences(scope)) { + set.insert(symbol); } + return set; } void ModFileWriter::PrepareRenamings(const Scope &scope) { @@ -304,8 +246,8 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) { } } // Collect symbols needed from other modules - SourceOrderedSymbolSet symbolsNeeded; - HarvestSymbolsNeededFromOtherModules(symbolsNeeded, scope); + SourceOrderedSymbolSet symbolsNeeded{ + HarvestSymbolsNeededFromOtherModules(scope)}; // Establish any necessary renamings of symbols in other modules // to their names in this scope, creating those new names when needed. auto &renamings{context_.moduleFileOutputRenamings()}; diff --git a/flang/lib/Semantics/symbol-dependence.cpp b/flang/lib/Semantics/symbol-dependence.cpp new file mode 100644 index 0000000000000..2591f609f3d00 --- /dev/null +++ b/flang/lib/Semantics/symbol-dependence.cpp @@ -0,0 +1,356 @@ +//===-- lib/Semantics/symbol-dependence.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Semantics/symbol-dependence.h" +#include "flang/Common/idioms.h" +#include "flang/Common/restorer.h" +#include "flang/Common/visit.h" +#include + +static constexpr bool EnableDebugging{false}; + +namespace Fortran::semantics { + +// Helper class that collects all of the symbol dependences for a +// given symbol. +class Collector { +public: + explicit Collector(int flags) : flags_{flags} {} + + void CollectSymbolDependences(const Symbol &); + UnorderedSymbolSet MustFollowDependences() { return std::move(dependences_); } + SymbolVector AllDependences() { return std::move(mentions_); } + +private: + // This symbol is depended upon and its declaration must precede + // the symbol of interest. + void MustFollow(const Symbol &x) { + if (!possibleImports_ || !DoesScopeContain(possibleImports_, x)) { + dependences_.insert(x); + } + } + // This symbol is depended upon, but is not necessarily a dependence + // that must precede the symbol of interest in the output of the + // topological sort. + void Need(const Symbol &x) { + if (mentioned_.insert(x).second) { + mentions_.emplace_back(x); + } + } + void Need(const Symbol *x) { + if (x) { + Need(*x); + } + } + + // These overloads of Collect() are mutally recursive, so they're + // packaged as member functions of a class. + void Collect(const Symbol &x) { + Need(x); + const auto *subp{x.detailsIf()}; + if ((subp && subp->isInterface()) || IsDummy(x) || + x.has() || x.has()) { + // can be forward-referenced + } else { + MustFollow(x); + } + } + void Collect(SymbolRef x) { Collect(*x); } + template void Collect(const std::optional &x) { + if (x) { + Collect(*x); + } + } + template void Collect(const A *x) { + if (x) { + Collect(*x); + } + } + void Collect(const UnorderedSymbolSet &x) { + for (const Symbol &symbol : x) { + Collect(symbol); + } + } + void Collect(const SourceOrderedSymbolSet &x) { + for (const Symbol &symbol : x) { + Collect(symbol); + } + } + void Collect(const SymbolVector &x) { + for (const Symbol &symbol : x) { + Collect(symbol); + } + } + void Collect(const Scope &x) { Collect(x.GetSymbols()); } + template void Collect(const evaluate::Expr &x) { + UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)}; + for (const Symbol &sym : exprSyms) { + if (!sym.owner().IsDerivedType()) { + Collect(sym); + } + } + } + void Collect(const DeclTypeSpec &type) { + if (type.category() == DeclTypeSpec::Category::Character) { + Collect(type.characterTypeSpec().length()); + } else { + Collect(type.AsDerived()); + } + } + void Collect(const DerivedTypeSpec &type) { + const Symbol &typeSym{type.originalTypeSymbol()}; + if (!derivedTypeReferenceCanBeForward_ || !type.parameters().empty()) { + MustFollow(typeSym); + } + Need(typeSym); + for (const auto &[_, value] : type.parameters()) { + Collect(value); + } + } + void Collect(const ParamValue &x) { Collect(x.GetExplicit()); } + void Collect(const Bound &x) { Collect(x.GetExplicit()); } + void Collect(const ShapeSpec &x) { + Collect(x.lbound()); + Collect(x.ubound()); + } + void Collect(const ArraySpec &x) { + for (const ShapeSpec &shapeSpec : x) { + Collect(shapeSpec); + } + } + + UnorderedSymbolSet mentioned_, dependences_; + SymbolVector mentions_; + int flags_{NoDependenceCollectionFlags}; + bool derivedTypeReferenceCanBeForward_{false}; + const Scope *possibleImports_{nullptr}; +}; + +void Collector::CollectSymbolDependences(const Symbol &symbol) { + if (symbol.has() || symbol.has()) { + // type will be picked up later for the function result, if any + } else if (symbol.has() || symbol.has() || + symbol.has()) { + } else if (IsAllocatableOrPointer(symbol) && symbol.owner().IsDerivedType()) { + bool saveCanBeForward{derivedTypeReferenceCanBeForward_}; + derivedTypeReferenceCanBeForward_ = true; + Collect(symbol.GetType()); + derivedTypeReferenceCanBeForward_ = saveCanBeForward; + } else { + Collect(symbol.GetType()); + } + common::visit( + common::visitors{ + [this, &symbol](const ObjectEntityDetails &x) { + Collect(x.shape()); + Collect(x.coshape()); + if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) { + Collect(x.init()); + } + Need(x.commonBlock()); + if (const auto *set{FindEquivalenceSet(symbol)}) { + for (const EquivalenceObject &equivObject : *set) { + Need(equivObject.symbol); + } + } + }, + [this, &symbol](const ProcEntityDetails &x) { + Collect(x.rawProcInterface()); + if (symbol.owner().IsDerivedType()) { + Collect(x.init()); + } + }, + [this](const ProcBindingDetails &x) { Need(x.symbol()); }, + [this, &symbol](const SubprogramDetails &x) { + // Note dummy arguments & result symbol without dependence, unless + // the subprogram is an interface block that might need to IMPORT + // a type. + bool needImports{x.isInterface()}; + auto restorer{common::ScopedSet( + possibleImports_, needImports ? symbol.scope() : nullptr)}; + for (const Symbol *dummy : x.dummyArgs()) { + if (dummy) { + Need(*dummy); + if (needImports) { + CollectSymbolDependences(*dummy); + } + } + } + if (x.isFunction()) { + Need(x.result()); + if (needImports) { + CollectSymbolDependences(x.result()); + } + } + }, + [this, &symbol](const DerivedTypeDetails &x) { + Collect(symbol.scope()); + for (const auto &[_, symbolRef] : x.finals()) { + Need(*symbolRef); + } + }, + [this](const GenericDetails &x) { + Collect(x.derivedType()); + Collect(x.specific()); + if (flags_ & IncludeUsesOfGenerics) { + for (const Symbol &use : x.uses()) { + Collect(use); + } + } + if (flags_ & IncludeSpecificsOfGenerics) { + for (const Symbol &specific : x.specificProcs()) { + Collect(specific); + } + } + }, + [this](const NamelistDetails &x) { + for (const Symbol &symbol : x.objects()) { + Collect(symbol); + } + }, + [this](const CommonBlockDetails &x) { + for (auto ref : x.objects()) { + Collect(*ref); + } + }, + [this](const UseDetails &x) { + if (flags_ & FollowUseAssociations) { + Need(x.symbol()); + } + }, + [this](const HostAssocDetails &x) { Need(x.symbol()); }, + [](const auto &) {}, + }, + symbol.details()); +} + +SymbolVector CollectAllDependences(const Scope &scope, int flags) { + SymbolVector basis{scope.GetSymbols()}; + return CollectAllDependences(basis, flags, &scope); +} + +// Returns a vector of symbols, topologically sorted by dependence +SymbolVector CollectAllDependences( + const SymbolVector &original, int flags, const Scope *forScope) { + std::queue work; + UnorderedSymbolSet enqueued; + for (const Symbol &symbol : original) { + if (!symbol.test(Symbol::Flag::CompilerCreated)) { + work.push(&symbol); + enqueued.insert(symbol); + } + } + // For each symbol, collect its dependences into "topology". + // The "visited" vector and "enqueued" set hold all of the + // symbols considered. + std::map topology; + std::vector visited; + visited.reserve(2 * original.size()); + std::optional forModuleName; + if (forScope && !(flags & NotJustForOneModule)) { + if (const Scope *forModule{FindModuleContaining(*forScope)}) { + forModuleName = forModule->GetName(); + } + } + while (!work.empty()) { + const Symbol &symbol{*work.front()}; + work.pop(); + visited.push_back(&symbol); + Collector collector{flags}; + bool doCollection{true}; + if (forModuleName) { + if (const Scope *symModule{FindModuleContaining(symbol.owner())}) { + if (auto symModName{symModule->GetName()}) { + doCollection = *forModuleName == *symModName; + } + } + } + if (doCollection) { + collector.CollectSymbolDependences(symbol); + } + auto dependences{collector.MustFollowDependences()}; + auto mentions{collector.AllDependences()}; + if constexpr (EnableDebugging) { + for (const Symbol &need : dependences) { + llvm::errs() << "symbol " << symbol << " must follow " << need << '\n'; + } + for (const Symbol &need : mentions) { + llvm::errs() << "symbol " << symbol << " needs " << need << '\n'; + } + } + CHECK(topology.find(&symbol) == topology.end()); + topology.emplace(&symbol, std::move(dependences)); + for (const Symbol &symbol : mentions) { + if (!symbol.test(Symbol::Flag::CompilerCreated)) { + if (enqueued.insert(symbol).second) { + work.push(&symbol); + } + } + } + } + CHECK(enqueued.size() == visited.size()); + // Topological sorting + // Subtle: This inverted topology map uses a SymbolVector, not a set + // of symbols, so that the order of symbols in the final output remains + // deterministic. + std::map invertedTopology; + for (const Symbol *symbol : visited) { + invertedTopology[symbol] = SymbolVector{}; + } + std::map numWaitingFor; + for (const Symbol *symbol : visited) { + auto topoIter{topology.find(symbol)}; + CHECK(topoIter != topology.end()); + const auto &needs{topoIter->second}; + if (needs.empty()) { + work.push(symbol); + } else { + numWaitingFor[symbol] = needs.size(); + for (const Symbol &need : needs) { + invertedTopology[&need].push_back(*symbol); + } + } + } + CHECK(visited.size() == work.size() + numWaitingFor.size()); + SymbolVector resultVector; + while (!work.empty()) { + const Symbol &symbol{*work.front()}; + work.pop(); + resultVector.push_back(symbol); + auto enqueuedIter{enqueued.find(symbol)}; + CHECK(enqueuedIter != enqueued.end()); + enqueued.erase(enqueuedIter); + if (auto invertedIter{invertedTopology.find(&symbol)}; + invertedIter != invertedTopology.end()) { + for (const Symbol &neededBy : invertedIter->second) { + std::size_t stillAwaiting{numWaitingFor[&neededBy] - 1}; + if (stillAwaiting == 0) { + work.push(&neededBy); + } else { + numWaitingFor[&neededBy] = stillAwaiting; + } + } + } + } + if constexpr (EnableDebugging) { + llvm::errs() << "Topological sort failed in CollectAllDependences\n"; + for (const Symbol &remnant : enqueued) { + auto topoIter{topology.find(&remnant)}; + CHECK(topoIter != topology.end()); + llvm::errs() << " remnant symbol " << remnant << " needs:\n"; + for (const Symbol &n : topoIter->second) { + llvm::errs() << " " << n << '\n'; + } + } + } + CHECK(enqueued.empty()); + CHECK(resultVector.size() == visited.size()); + return resultVector; +} + +} // namespace Fortran::semantics diff --git a/flang/test/Semantics/modfile44.f90 b/flang/test/Semantics/modfile44.f90 index 23d93b18a5a1f..48f6250cf6ce8 100644 --- a/flang/test/Semantics/modfile44.f90 +++ b/flang/test/Semantics/modfile44.f90 @@ -45,6 +45,8 @@ function foo(j) result(res) !Expect: m2.mod !module m2 +!use m1,only:xyz +!private::xyz !contains !function foo(j) result(res) !use m1,only:xyz diff --git a/flang/test/Semantics/modfile69.f90 b/flang/test/Semantics/modfile69.f90 index 6586e0524f5ea..f969573efdd78 100644 --- a/flang/test/Semantics/modfile69.f90 +++ b/flang/test/Semantics/modfile69.f90 @@ -36,6 +36,8 @@ subroutine sub(x) !Expect: m3.mod !module m3 +!use m2,only:bar +!private::bar !contains !subroutine sub(x) !use m2,only:bar