[flang] Add general symbol dependence collection utility #146968

Open
wants to merge 1 commit into
base: main
from

Conversation

klausler
Contributor

@klausler klausler commented Jul 3, 2025

Replace HarvestSymbolsNeededFromOtherModules() in mod-file.cpp with a general utility function in Semantics. This new code will find other uses in further rework of hermetic module file generation as the means by which the necessary subsets of symbols in dependency modules are collected.

@llvmbot added the flang (Flang issues not falling into any other category) and flang:semantics labels on Jul 3, 2025
@llvmbot
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-flang-semantics

Author: Peter Klausler (klausler)


Full diff: https://github.com/llvm/llvm-project/pull/146968.diff

6 Files Affected:

  • (added) flang/include/flang/Semantics/symbol-dependence.h (+36)
  • (modified) flang/lib/Semantics/CMakeLists.txt (+1)
  • (modified) flang/lib/Semantics/mod-file.cpp (+9-67)
  • (added) flang/lib/Semantics/symbol-dependence.cpp (+356)
  • (modified) flang/test/Semantics/modfile44.f90 (+2)
  • (modified) flang/test/Semantics/modfile69.f90 (+2)
diff --git a/flang/include/flang/Semantics/symbol-dependence.h b/flang/include/flang/Semantics/symbol-dependence.h
new file mode 100644
index 0000000000000..9bcff564a4c04
--- /dev/null
+++ b/flang/include/flang/Semantics/symbol-dependence.h
@@ -0,0 +1,36 @@
+//===-- include/flang/Semantics/symbol-dependence.h -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_
+#define FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_
+
+#include "flang/Semantics/symbol.h"
+
+namespace Fortran::semantics {
+
+// For a set or scope of symbols, computes the transitive closure of their
+// dependences due to their types, bounds, specific procedures, interfaces,
+// initialization, storage association, &c. Includes the original symbol
+// or members of the original set.  Does not include dependences from
+// subprogram definitions, only their interfaces.
+enum DependenceCollectionFlags {
+  NoDependenceCollectionFlags = 0,
+  IncludeOriginalSymbols = 1 << 0,
+  FollowUseAssociations = 1 << 1,
+  IncludeSpecificsOfGenerics = 1 << 2,
+  IncludeUsesOfGenerics = 1 << 3,
+  NotJustForOneModule = 1 << 4,
+};
+
+SymbolVector CollectAllDependences(const SymbolVector &,
+    int = NoDependenceCollectionFlags, const Scope * = nullptr);
+SymbolVector CollectAllDependences(
+    const Scope &, int = NoDependenceCollectionFlags);
+
+} // namespace Fortran::semantics
+#endif // FORTRAN_SEMANTICS_SYMBOL_DEPENDENCE_H_
diff --git a/flang/lib/Semantics/CMakeLists.txt b/flang/lib/Semantics/CMakeLists.txt
index 109bc2dbb8569..c711103e793d8 100644
--- a/flang/lib/Semantics/CMakeLists.txt
+++ b/flang/lib/Semantics/CMakeLists.txt
@@ -48,6 +48,7 @@ add_flang_library(FortranSemantics
   runtime-type-info.cpp
   scope.cpp
   semantics.cpp
+  symbol-dependence.cpp
   symbol.cpp
   tools.cpp
   type.cpp
diff --git a/flang/lib/Semantics/mod-file.cpp b/flang/lib/Semantics/mod-file.cpp
index 82c8536902eb2..ab851587bb49a 100644
--- a/flang/lib/Semantics/mod-file.cpp
+++ b/flang/lib/Semantics/mod-file.cpp
@@ -15,6 +15,7 @@
 #include "flang/Parser/unparse.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/semantics.h"
+#include "flang/Semantics/symbol-dependence.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "llvm/Support/FileSystem.h"
@@ -223,72 +224,13 @@ std::string ModFileWriter::GetAsString(const Symbol &symbol) {
 // Collect symbols from constant and specification expressions that are being
 // referenced directly from other modules; they may require new USE
 // associations.
-static void HarvestSymbolsNeededFromOtherModules(
-    SourceOrderedSymbolSet &, const Scope &);
-static void HarvestSymbolsNeededFromOtherModules(
-    SourceOrderedSymbolSet &set, const Symbol &symbol, const Scope &scope) {
-  auto HarvestBound{[&](const Bound &bound) {
-    if (const auto &expr{bound.GetExplicit()}) {
-      for (SymbolRef ref : evaluate::CollectSymbols(*expr)) {
-        set.emplace(*ref);
-      }
-    }
-  }};
-  auto HarvestShapeSpec{[&](const ShapeSpec &shapeSpec) {
-    HarvestBound(shapeSpec.lbound());
-    HarvestBound(shapeSpec.ubound());
-  }};
-  auto HarvestArraySpec{[&](const ArraySpec &arraySpec) {
-    for (const auto &shapeSpec : arraySpec) {
-      HarvestShapeSpec(shapeSpec);
-    }
-  }};
-
-  if (symbol.has<DerivedTypeDetails>()) {
-    if (symbol.scope()) {
-      HarvestSymbolsNeededFromOtherModules(set, *symbol.scope());
-    }
-  } else if (const auto &generic{symbol.detailsIf<GenericDetails>()};
-             generic && generic->derivedType()) {
-    const Symbol &dtSym{*generic->derivedType()};
-    if (dtSym.has<DerivedTypeDetails>()) {
-      if (dtSym.scope()) {
-        HarvestSymbolsNeededFromOtherModules(set, *dtSym.scope());
-      }
-    } else {
-      CHECK(dtSym.has<UseDetails>() || dtSym.has<UseErrorDetails>());
-    }
-  } else if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()}) {
-    HarvestArraySpec(object->shape());
-    HarvestArraySpec(object->coshape());
-    if (IsNamedConstant(symbol) || scope.IsDerivedType()) {
-      if (object->init()) {
-        for (SymbolRef ref : evaluate::CollectSymbols(*object->init())) {
-          set.emplace(*ref);
-        }
-      }
-    }
-  } else if (const auto *proc{symbol.detailsIf<ProcEntityDetails>()}) {
-    if (proc->init() && *proc->init() && scope.IsDerivedType()) {
-      set.emplace(**proc->init());
-    }
-  } else if (const auto *subp{symbol.detailsIf<SubprogramDetails>()}) {
-    for (const Symbol *dummy : subp->dummyArgs()) {
-      if (dummy) {
-        HarvestSymbolsNeededFromOtherModules(set, *dummy, scope);
-      }
-    }
-    if (subp->isFunction()) {
-      HarvestSymbolsNeededFromOtherModules(set, subp->result(), scope);
-    }
-  }
-}
-
-static void HarvestSymbolsNeededFromOtherModules(
-    SourceOrderedSymbolSet &set, const Scope &scope) {
-  for (const auto &[_, symbol] : scope) {
-    HarvestSymbolsNeededFromOtherModules(set, *symbol, scope);
+static SourceOrderedSymbolSet HarvestSymbolsNeededFromOtherModules(
+    const Scope &scope) {
+  SourceOrderedSymbolSet set;
+  for (const Symbol &symbol : CollectAllDependences(scope)) {
+    set.insert(symbol);
   }
+  return set;
 }
 
 void ModFileWriter::PrepareRenamings(const Scope &scope) {
@@ -304,8 +246,8 @@ void ModFileWriter::PrepareRenamings(const Scope &scope) {
     }
   }
   // Collect symbols needed from other modules
-  SourceOrderedSymbolSet symbolsNeeded;
-  HarvestSymbolsNeededFromOtherModules(symbolsNeeded, scope);
+  SourceOrderedSymbolSet symbolsNeeded{
+      HarvestSymbolsNeededFromOtherModules(scope)};
   // Establish any necessary renamings of symbols in other modules
   // to their names in this scope, creating those new names when needed.
   auto &renamings{context_.moduleFileOutputRenamings()};
diff --git a/flang/lib/Semantics/symbol-dependence.cpp b/flang/lib/Semantics/symbol-dependence.cpp
new file mode 100644
index 0000000000000..2591f609f3d00
--- /dev/null
+++ b/flang/lib/Semantics/symbol-dependence.cpp
@@ -0,0 +1,356 @@
+//===-- lib/Semantics/symbol-dependence.cpp -------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Semantics/symbol-dependence.h"
+#include "flang/Common/idioms.h"
+#include "flang/Common/restorer.h"
+#include "flang/Common/visit.h"
+#include <queue>
+
+static constexpr bool EnableDebugging{false};
+
+namespace Fortran::semantics {
+
+// Helper class that collects all of the symbol dependences for a
+// given symbol.
+class Collector {
+public:
+  explicit Collector(int flags) : flags_{flags} {}
+
+  void CollectSymbolDependences(const Symbol &);
+  UnorderedSymbolSet MustFollowDependences() { return std::move(dependences_); }
+  SymbolVector AllDependences() { return std::move(mentions_); }
+
+private:
+  // This symbol is depended upon and its declaration must precede
+  // the symbol of interest.
+  void MustFollow(const Symbol &x) {
+    if (!possibleImports_ || !DoesScopeContain(possibleImports_, x)) {
+      dependences_.insert(x);
+    }
+  }
+  // This symbol is depended upon, but is not necessarily a dependence
+  // that must precede the symbol of interest in the output of the
+  // topological sort.
+  void Need(const Symbol &x) {
+    if (mentioned_.insert(x).second) {
+      mentions_.emplace_back(x);
+    }
+  }
+  void Need(const Symbol *x) {
+    if (x) {
+      Need(*x);
+    }
+  }
+
+  // These overloads of Collect() are mutually recursive, so they're
+  // packaged as member functions of a class.
+  void Collect(const Symbol &x) {
+    Need(x);
+    const auto *subp{x.detailsIf<SubprogramDetails>()};
+    if ((subp && subp->isInterface()) || IsDummy(x) ||
+        x.has<CommonBlockDetails>() || x.has<NamelistDetails>()) {
+      // can be forward-referenced
+    } else {
+      MustFollow(x);
+    }
+  }
+  void Collect(SymbolRef x) { Collect(*x); }
+  template <typename A> void Collect(const std::optional<A> &x) {
+    if (x) {
+      Collect(*x);
+    }
+  }
+  template <typename A> void Collect(const A *x) {
+    if (x) {
+      Collect(*x);
+    }
+  }
+  void Collect(const UnorderedSymbolSet &x) {
+    for (const Symbol &symbol : x) {
+      Collect(symbol);
+    }
+  }
+  void Collect(const SourceOrderedSymbolSet &x) {
+    for (const Symbol &symbol : x) {
+      Collect(symbol);
+    }
+  }
+  void Collect(const SymbolVector &x) {
+    for (const Symbol &symbol : x) {
+      Collect(symbol);
+    }
+  }
+  void Collect(const Scope &x) { Collect(x.GetSymbols()); }
+  template <typename T> void Collect(const evaluate::Expr<T> &x) {
+    UnorderedSymbolSet exprSyms{evaluate::CollectSymbols(x)};
+    for (const Symbol &sym : exprSyms) {
+      if (!sym.owner().IsDerivedType()) {
+        Collect(sym);
+      }
+    }
+  }
+  void Collect(const DeclTypeSpec &type) {
+    if (type.category() == DeclTypeSpec::Category::Character) {
+      Collect(type.characterTypeSpec().length());
+    } else {
+      Collect(type.AsDerived());
+    }
+  }
+  void Collect(const DerivedTypeSpec &type) {
+    const Symbol &typeSym{type.originalTypeSymbol()};
+    if (!derivedTypeReferenceCanBeForward_ || !type.parameters().empty()) {
+      MustFollow(typeSym);
+    }
+    Need(typeSym);
+    for (const auto &[_, value] : type.parameters()) {
+      Collect(value);
+    }
+  }
+  void Collect(const ParamValue &x) { Collect(x.GetExplicit()); }
+  void Collect(const Bound &x) { Collect(x.GetExplicit()); }
+  void Collect(const ShapeSpec &x) {
+    Collect(x.lbound());
+    Collect(x.ubound());
+  }
+  void Collect(const ArraySpec &x) {
+    for (const ShapeSpec &shapeSpec : x) {
+      Collect(shapeSpec);
+    }
+  }
+
+  UnorderedSymbolSet mentioned_, dependences_;
+  SymbolVector mentions_;
+  int flags_{NoDependenceCollectionFlags};
+  bool derivedTypeReferenceCanBeForward_{false};
+  const Scope *possibleImports_{nullptr};
+};
+
+void Collector::CollectSymbolDependences(const Symbol &symbol) {
+  if (symbol.has<ProcBindingDetails>() || symbol.has<SubprogramDetails>()) {
+    // type will be picked up later for the function result, if any
+  } else if (symbol.has<UseDetails>() || symbol.has<UseErrorDetails>() ||
+      symbol.has<HostAssocDetails>()) {
+  } else if (IsAllocatableOrPointer(symbol) && symbol.owner().IsDerivedType()) {
+    bool saveCanBeForward{derivedTypeReferenceCanBeForward_};
+    derivedTypeReferenceCanBeForward_ = true;
+    Collect(symbol.GetType());
+    derivedTypeReferenceCanBeForward_ = saveCanBeForward;
+  } else {
+    Collect(symbol.GetType());
+  }
+  common::visit(
+      common::visitors{
+          [this, &symbol](const ObjectEntityDetails &x) {
+            Collect(x.shape());
+            Collect(x.coshape());
+            if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {
+              Collect(x.init());
+            }
+            Need(x.commonBlock());
+            if (const auto *set{FindEquivalenceSet(symbol)}) {
+              for (const EquivalenceObject &equivObject : *set) {
+                Need(equivObject.symbol);
+              }
+            }
+          },
+          [this, &symbol](const ProcEntityDetails &x) {
+            Collect(x.rawProcInterface());
+            if (symbol.owner().IsDerivedType()) {
+              Collect(x.init());
+            }
+          },
+          [this](const ProcBindingDetails &x) { Need(x.symbol()); },
+          [this, &symbol](const SubprogramDetails &x) {
+            // Note dummy arguments & result symbol without dependence, unless
+            // the subprogram is an interface block that might need to IMPORT
+            // a type.
+            bool needImports{x.isInterface()};
+            auto restorer{common::ScopedSet(
+                possibleImports_, needImports ? symbol.scope() : nullptr)};
+            for (const Symbol *dummy : x.dummyArgs()) {
+              if (dummy) {
+                Need(*dummy);
+                if (needImports) {
+                  CollectSymbolDependences(*dummy);
+                }
+              }
+            }
+            if (x.isFunction()) {
+              Need(x.result());
+              if (needImports) {
+                CollectSymbolDependences(x.result());
+              }
+            }
+          },
+          [this, &symbol](const DerivedTypeDetails &x) {
+            Collect(symbol.scope());
+            for (const auto &[_, symbolRef] : x.finals()) {
+              Need(*symbolRef);
+            }
+          },
+          [this](const GenericDetails &x) {
+            Collect(x.derivedType());
+            Collect(x.specific());
+            if (flags_ & IncludeUsesOfGenerics) {
+              for (const Symbol &use : x.uses()) {
+                Collect(use);
+              }
+            }
+            if (flags_ & IncludeSpecificsOfGenerics) {
+              for (const Symbol &specific : x.specificProcs()) {
+                Collect(specific);
+              }
+            }
+          },
+          [this](const NamelistDetails &x) {
+            for (const Symbol &symbol : x.objects()) {
+              Collect(symbol);
+            }
+          },
+          [this](const CommonBlockDetails &x) {
+            for (auto ref : x.objects()) {
+              Collect(*ref);
+            }
+          },
+          [this](const UseDetails &x) {
+            if (flags_ & FollowUseAssociations) {
+              Need(x.symbol());
+            }
+          },
+          [this](const HostAssocDetails &x) { Need(x.symbol()); },
+          [](const auto &) {},
+      },
+      symbol.details());
+}
+
+SymbolVector CollectAllDependences(const Scope &scope, int flags) {
+  SymbolVector basis{scope.GetSymbols()};
+  return CollectAllDependences(basis, flags, &scope);
+}
+
+// Returns a vector of symbols, topologically sorted by dependence
+SymbolVector CollectAllDependences(
+    const SymbolVector &original, int flags, const Scope *forScope) {
+  std::queue<const Symbol *> work;
+  UnorderedSymbolSet enqueued;
+  for (const Symbol &symbol : original) {
+    if (!symbol.test(Symbol::Flag::CompilerCreated)) {
+      work.push(&symbol);
+      enqueued.insert(symbol);
+    }
+  }
+  // For each symbol, collect its dependences into "topology".
+  // The "visited" vector and "enqueued" set hold all of the
+  // symbols considered.
+  std::map<const Symbol *, UnorderedSymbolSet> topology;
+  std::vector<const Symbol *> visited;
+  visited.reserve(2 * original.size());
+  std::optional<SourceName> forModuleName;
+  if (forScope && !(flags & NotJustForOneModule)) {
+    if (const Scope *forModule{FindModuleContaining(*forScope)}) {
+      forModuleName = forModule->GetName();
+    }
+  }
+  while (!work.empty()) {
+    const Symbol &symbol{*work.front()};
+    work.pop();
+    visited.push_back(&symbol);
+    Collector collector{flags};
+    bool doCollection{true};
+    if (forModuleName) {
+      if (const Scope *symModule{FindModuleContaining(symbol.owner())}) {
+        if (auto symModName{symModule->GetName()}) {
+          doCollection = *forModuleName == *symModName;
+        }
+      }
+    }
+    if (doCollection) {
+      collector.CollectSymbolDependences(symbol);
+    }
+    auto dependences{collector.MustFollowDependences()};
+    auto mentions{collector.AllDependences()};
+    if constexpr (EnableDebugging) {
+      for (const Symbol &need : dependences) {
+        llvm::errs() << "symbol " << symbol << " must follow " << need << '\n';
+      }
+      for (const Symbol &need : mentions) {
+        llvm::errs() << "symbol " << symbol << " needs " << need << '\n';
+      }
+    }
+    CHECK(topology.find(&symbol) == topology.end());
+    topology.emplace(&symbol, std::move(dependences));
+    for (const Symbol &symbol : mentions) {
+      if (!symbol.test(Symbol::Flag::CompilerCreated)) {
+        if (enqueued.insert(symbol).second) {
+          work.push(&symbol);
+        }
+      }
+    }
+  }
+  CHECK(enqueued.size() == visited.size());
+  // Topological sorting
+  // Subtle: This inverted topology map uses a SymbolVector, not a set
+  // of symbols, so that the order of symbols in the final output remains
+  // deterministic.
+  std::map<const Symbol *, SymbolVector> invertedTopology;
+  for (const Symbol *symbol : visited) {
+    invertedTopology[symbol] = SymbolVector{};
+  }
+  std::map<const Symbol *, std::size_t> numWaitingFor;
+  for (const Symbol *symbol : visited) {
+    auto topoIter{topology.find(symbol)};
+    CHECK(topoIter != topology.end());
+    const auto &needs{topoIter->second};
+    if (needs.empty()) {
+      work.push(symbol);
+    } else {
+      numWaitingFor[symbol] = needs.size();
+      for (const Symbol &need : needs) {
+        invertedTopology[&need].push_back(*symbol);
+      }
+    }
+  }
+  CHECK(visited.size() == work.size() + numWaitingFor.size());
+  SymbolVector resultVector;
+  while (!work.empty()) {
+    const Symbol &symbol{*work.front()};
+    work.pop();
+    resultVector.push_back(symbol);
+    auto enqueuedIter{enqueued.find(symbol)};
+    CHECK(enqueuedIter != enqueued.end());
+    enqueued.erase(enqueuedIter);
+    if (auto invertedIter{invertedTopology.find(&symbol)};
+        invertedIter != invertedTopology.end()) {
+      for (const Symbol &neededBy : invertedIter->second) {
+        std::size_t stillAwaiting{numWaitingFor[&neededBy] - 1};
+        if (stillAwaiting == 0) {
+          work.push(&neededBy);
+        } else {
+          numWaitingFor[&neededBy] = stillAwaiting;
+        }
+      }
+    }
+  }
+  if constexpr (EnableDebugging) {
+    llvm::errs() << "Topological sort failed in CollectAllDependences\n";
+    for (const Symbol &remnant : enqueued) {
+      auto topoIter{topology.find(&remnant)};
+      CHECK(topoIter != topology.end());
+      llvm::errs() << "  remnant symbol " << remnant << " needs:\n";
+      for (const Symbol &n : topoIter->second) {
+        llvm::errs() << "   " << n << '\n';
+      }
+    }
+  }
+  CHECK(enqueued.empty());
+  CHECK(resultVector.size() == visited.size());
+  return resultVector;
+}
+
+} // namespace Fortran::semantics
diff --git a/flang/test/Semantics/modfile44.f90 b/flang/test/Semantics/modfile44.f90
index 23d93b18a5a1f..48f6250cf6ce8 100644
--- a/flang/test/Semantics/modfile44.f90
+++ b/flang/test/Semantics/modfile44.f90
@@ -45,6 +45,8 @@ function foo(j) result(res)
 
 !Expect: m2.mod
 !module m2
+!use m1,only:xyz
+!private::xyz
 !contains
 !function foo(j) result(res)
 !use m1,only:xyz
diff --git a/flang/test/Semantics/modfile69.f90 b/flang/test/Semantics/modfile69.f90
index 6586e0524f5ea..f969573efdd78 100644
--- a/flang/test/Semantics/modfile69.f90
+++ b/flang/test/Semantics/modfile69.f90
@@ -36,6 +36,8 @@ subroutine sub(x)
 
 !Expect: m3.mod
 !module m3
+!use m2,only:bar
+!private::bar
 !contains
 !subroutine sub(x)
 !use m2,only:bar
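
For readers following the algorithm in `CollectAllDependences`, the two phases (a worklist closure, then a queue-based topological sort in the style of Kahn's algorithm) can be sketched without any flang types. This is a simplified illustration under stated assumptions: names are plain strings, the `Graph` alias and `CollectSorted` function are hypothetical, and a single edge set drives both phases, whereas the PR grows the closure from "mentions" and sorts only by the "must follow" subset.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for the per-symbol Collector: maps each name to
// the names that must precede it.
using Graph = std::map<std::string, std::vector<std::string>>;

// Returns every name reachable from `roots`, ordered so that each name
// follows all of the names it depends on. Assumes the graph is acyclic.
std::vector<std::string> CollectSorted(
    const std::vector<std::string> &roots, const Graph &deps) {
  // Phase 1: worklist closure. `visited` records discovery order so
  // that the final output stays deterministic.
  std::queue<std::string> work;
  std::set<std::string> enqueued;
  std::vector<std::string> visited;
  for (const auto &root : roots) {
    if (enqueued.insert(root).second) {
      work.push(root);
    }
  }
  while (!work.empty()) {
    std::string name{work.front()};
    work.pop();
    visited.push_back(name);
    if (auto it{deps.find(name)}; it != deps.end()) {
      for (const auto &need : it->second) {
        if (enqueued.insert(need).second) {
          work.push(need);
        }
      }
    }
  }
  // Phase 2: invert the edges and count unmet dependences per name.
  std::map<std::string, std::vector<std::string>> neededBy;
  std::map<std::string, std::size_t> numWaitingFor;
  for (const auto &name : visited) {
    std::set<std::string> needs; // de-duplicated dependences
    if (auto it{deps.find(name)}; it != deps.end()) {
      needs.insert(it->second.begin(), it->second.end());
    }
    if (needs.empty()) {
      work.push(name); // ready immediately
    } else {
      numWaitingFor[name] = needs.size();
      for (const auto &need : needs) {
        neededBy[need].push_back(name);
      }
    }
  }
  // Emit ready names; each emission may release names waiting on it.
  std::vector<std::string> result;
  while (!work.empty()) {
    std::string name{work.front()};
    work.pop();
    result.push_back(name);
    for (const auto &waiter : neededBy[name]) {
      if (--numWaitingFor[waiter] == 0) {
        work.push(waiter);
      }
    }
  }
  // A size mismatch here would indicate a dependence cycle.
  assert(result.size() == visited.size());
  return result;
}
```

Storing the inverted edges in vectors filled in discovery order, rather than in pointer-keyed sets, mirrors the "Subtle" comment in the diff: it keeps the output order reproducible from run to run.
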

Contributor

@clementval clementval left a comment

LGTM

Contributor

@akuhlens akuhlens left a comment

LGTM

Labels
flang:semantics flang Flang issues not falling into any other category
4 participants