[offload][SYCL] Add Module splitting by categories. #131347

Open
wants to merge 14 commits into
base: main
Changes from 1 commit
101 changes: 0 additions & 101 deletions llvm/include/llvm/Transforms/Utils/SYCLUtils.h

This file was deleted.

34 changes: 18 additions & 16 deletions llvm/include/llvm/Transforms/Utils/SplitModuleByCategory.h
@@ -8,8 +8,8 @@
// Functionality to split a module by categories.
//===----------------------------------------------------------------------===//

#ifndef LLVM_FRONTEND_SYCL_SPLIT_MODULE_H
#define LLVM_FRONTEND_SYCL_SPLIT_MODULE_H
#ifndef LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H
#define LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H

#include "llvm/ADT/STLFunctionalExtras.h"

@@ -22,21 +22,23 @@ namespace llvm {
class Module;
class Function;

namespace sycl {

/// FunctionCategorizer returns integer category for the given Function.
/// Otherwise, it returns std::nullopt if function doesn't have a category.
using FunctionCategorizer = function_ref<std::optional<int>(const Function &F)>;

using PostSplitCallbackType = function_ref<void(std::unique_ptr<Module> Part)>;

/// Splits the given module \p M.
/// Every split image is being passed to \p Callback for further possible
/// Splits the given module \p M using the given \p FunctionCategorizer.
/// \p FunctionCategorizer returns integer category for an input Function.
Contributor

A side note: I think it'd be more helpful (at least putting on my AMD hat) to be able to determine where a global variable goes as well, if we'd like to make this pass generic enough to support all potential targets. The reason is that, for AMDGPU, we probably need to categorize all functions that could potentially reference a global variable into the same module, due to the lowering of LDS (shared) variables.

Member

Leave this for later.

/// It may return std::nullopt if a function doesn't have a category.
/// Module's functions are being grouped by categories. Every such group
/// populates a call graph containing group's functions themselves and all
/// reachable functions and globals. Split outputs are populated from each call
/// graph associated with some category.
///
/// Every split output is being passed to \p Callback for further possible
/// processing.
void splitModuleByCategory(std::unique_ptr<Module> M, FunctionCategorizer FC,
PostSplitCallbackType Callback);
///
/// Currently, the supported targets are SPIRV, AMDGPU and NVPTX.
Contributor

Please update this comment. I am not sure if the targets are restrictive. I think the restriction is whether the input module has recursive calls or not.

Thanks

Contributor

@shiltian, Jun 9, 2025

Agreed. Now that we have callbacks, it should just work for all targets.

Update:

This is probably because of the isKernel function.

Contributor Author

> This is probably because of the isKernel function.

Yes, and the algorithm was implemented on the assumption that inputs are heterogeneous programs, which usually don't have recursion.

Member

> which usually don't have recursion.

FWIW, "usually" is doing the heavy lifting here. Please do not assume anything about GPU codes that is not required. So, recursion should be assumed to happen.
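
To make the recursion point concrete, here is a minimal sketch of a reachability walk that terminates on cyclic call graphs. It is not the dependency graph code from this patch: `CallGraph`, `collectReachable`, and `makeRecursiveExample` are hypothetical names, and function names stand in for `llvm::Function` objects. The only point is that a visited set makes the traversal safe whether or not the input has recursion.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a call graph: caller name -> callee names.
// This is NOT the dependency graph class used in the patch.
using CallGraph = std::map<std::string, std::vector<std::string>>;

// Collects every function reachable from the given entry points.
// The Visited set guarantees termination even when the graph contains
// cycles (direct or mutual recursion).
std::set<std::string>
collectReachable(const CallGraph &G,
                 const std::vector<std::string> &EntryPoints) {
  std::set<std::string> Visited;
  std::vector<std::string> Worklist(EntryPoints.begin(), EntryPoints.end());
  while (!Worklist.empty()) {
    std::string Current = Worklist.back();
    Worklist.pop_back();
    // insert() reports whether the node is new; skip nodes seen before.
    if (!Visited.insert(Current).second)
      continue;
    auto It = G.find(Current);
    if (It == G.end())
      continue; // No recorded callees for this function.
    for (const std::string &Callee : It->second)
      Worklist.push_back(Callee);
  }
  return Visited;
}

// Small graph with mutual recursion: kernelA -> f -> g -> f.
CallGraph makeRecursiveExample() {
  return {{"kernelA", {"f"}}, {"f", {"g"}}, {"g", {"f"}}, {"kernelB", {"h"}}};
}
```

Calling `collectReachable(makeRecursiveExample(), {"kernelA"})` visits `kernelA`, `f`, and `g` once each and never reaches `h`, even though `f` and `g` call each other.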

void splitModuleByCategory(
std::unique_ptr<Module> M,
function_ref<std::optional<int>(const Function &F)> FunctionCategorizer,
function_ref<void(std::unique_ptr<Module> Part)> Callback);
Member

It might be helpful to pass the category to the callback in addition to the module. But if this is not needed right now, we can do this later.
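
As a rough illustration of that suggestion, the sketch below shows one possible callback shape that receives the category alongside the split part. Everything here is hypothetical: `FakeFunction`, `FakeModule`, `CategoryCallback`, and `splitByCategorySketch` are stand-ins rather than LLVM types or the signature in this patch, and the sketch skips dependency collection entirely.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins; the patch itself works on llvm::Function and
// llvm::Module.
struct FakeFunction {
  std::string Name;
  std::optional<int> Category; // std::nullopt means "no category".
};
struct FakeModule {
  std::vector<std::string> FunctionNames;
};

// Assumed callback shape: the split part plus the category it came from.
using CategoryCallback =
    std::function<void(std::unique_ptr<FakeModule> Part, int Category)>;

void splitByCategorySketch(const std::vector<FakeFunction> &Functions,
                           const CategoryCallback &Callback) {
  // std::map iterates in ascending key order, so parts are emitted in a
  // deterministic order.
  std::map<int, std::unique_ptr<FakeModule>> Parts;
  for (const FakeFunction &F : Functions) {
    if (!F.Category)
      continue; // Uncategorized functions do not start a part.
    std::unique_ptr<FakeModule> &Part = Parts[*F.Category];
    if (!Part)
      Part = std::make_unique<FakeModule>();
    Part->FunctionNames.push_back(F.Name);
  }
  for (auto &[Category, Part] : Parts)
    Callback(std::move(Part), Category);
}
```

With this shape a consumer can, for example, derive an output file name from the category without recomputing it from the part's contents.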


} // namespace sycl
} // namespace llvm

#endif // LLVM_FRONTEND_SYCL_SPLIT_MODULE_H
#endif // LLVM_TRANSFORM_UTILS_SPLIT_MODULE_BY_CATEGORY_H
1 change: 0 additions & 1 deletion llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -84,7 +84,6 @@ add_llvm_component_library(LLVMTransformUtils
SplitModule.cpp
SplitModuleByCategory.cpp
StripNonLineTableDebugInfo.cpp
SYCLUtils.cpp
SymbolRewriter.cpp
UnifyFunctionExitNodes.cpp
UnifyLoopExits.cpp
117 changes: 0 additions & 117 deletions llvm/lib/Transforms/Utils/SYCLUtils.cpp

This file was deleted.

35 changes: 18 additions & 17 deletions llvm/lib/Transforms/Utils/SplitModuleByCategory.cpp
@@ -25,18 +25,11 @@
#include <utility>

using namespace llvm;
using namespace llvm::sycl;

#define DEBUG_TYPE "sycl-split-module"
#define DEBUG_TYPE "split-module-by-category"

namespace {

bool isKernel(const Function &F) {
return F.getCallingConv() == CallingConv::SPIR_KERNEL ||
F.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
F.getCallingConv() == CallingConv::PTX_Kernel;
}

// A vector that contains a group of function with the same category.
using EntryPointSet = SetVector<const Function *>;

@@ -106,6 +99,12 @@ class ModuleDesc {
#endif
};

bool isKernel(const Function &F) {
return F.getCallingConv() == CallingConv::SPIR_KERNEL ||
F.getCallingConv() == CallingConv::AMDGPU_KERNEL ||
F.getCallingConv() == CallingConv::PTX_Kernel;
}

// Represents "dependency" or "use" graph of global objects (functions and
// global variables) in a module. It is used during device code split to
// understand which global variables and functions (other than entry points)
@@ -299,17 +298,18 @@ class ModuleSplitter {
bool hasMoreSplits() const { return Groups.size() > 0; }
};

EntryPointGroupVec selectEntryPointGroups(const Module &M,
FunctionCategorizer FC) {
EntryPointGroupVec
selectEntryPointGroups(const Module &M,
function_ref<std::optional<int>(const Function &F)> FC) {
// std::map is used here to ensure stable ordering of entry point groups,
// which is based on their contents, this greatly helps LIT tests
std::map<int, EntryPointSet> EntryPointsMap;

for (const auto &F : M.functions()) {
if (auto Key = FC(F); Key) {
auto It = EntryPointsMap.find(*Key);
if (auto Category = FC(F); Category) {
auto It = EntryPointsMap.find(*Category);
if (It == EntryPointsMap.end())
It = EntryPointsMap.emplace(*Key, EntryPointSet()).first;
It = EntryPointsMap.emplace(*Category, EntryPointSet()).first;

It->second.insert(&F);
}
@@ -326,10 +326,11 @@ EntryPointGroupVec selectEntryPointGroups(const Module &M,

} // namespace
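
The grouping step in `selectEntryPointGroups` can be sketched outside LLVM as follows. This is an illustration under stated assumptions, not the patch's code: `FakeFunction`, `Categorizer`, `EntryPointGroup`, and `groupEntryPoints` are hypothetical names, and a plain `std::vector` replaces `SetVector`. It shows the same idea of bucketing by an optional integer category in a `std::map` so that group order is deterministic, with `operator[]` replacing the explicit `find` plus `emplace` pair.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for llvm::Function; only the name matters here.
struct FakeFunction {
  std::string Name;
};

// Same shape as the categorizer in the patch, but over the stand-in type.
using Categorizer = std::function<std::optional<int>(const FakeFunction &)>;

struct EntryPointGroup {
  int Category;
  std::vector<const FakeFunction *> Functions;
};

std::vector<EntryPointGroup>
groupEntryPoints(const std::vector<FakeFunction> &Functions,
                 const Categorizer &FC) {
  // std::map keeps keys sorted, which gives a stable group order.
  std::map<int, std::vector<const FakeFunction *>> ByCategory;
  for (const FakeFunction &F : Functions)
    if (std::optional<int> Category = FC(F))
      // operator[] creates an empty bucket the first time a key is seen.
      ByCategory[*Category].push_back(&F);

  std::vector<EntryPointGroup> Groups;
  Groups.reserve(ByCategory.size());
  for (auto &[Category, Members] : ByCategory)
    Groups.push_back({Category, std::move(Members)});
  return Groups;
}
```

Groups come out ordered by category value, and functions within a group keep their order of appearance in the input.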

void llvm::sycl::splitModuleByCategory(std::unique_ptr<Module> M,
FunctionCategorizer FC,
PostSplitCallbackType Callback) {
EntryPointGroupVec Groups = selectEntryPointGroups(*M, FC);
void llvm::splitModuleByCategory(
std::unique_ptr<Module> M,
function_ref<std::optional<int>(const Function &F)> FunctionCategorizer,
function_ref<void(std::unique_ptr<Module> Part)> Callback) {
EntryPointGroupVec Groups = selectEntryPointGroups(*M, FunctionCategorizer);
ModuleDesc MD = std::move(M);
ModuleSplitter Splitter(std::move(MD), std::move(Groups));
while (Splitter.hasMoreSplits()) {