Skip to content

Commit e448888

Browse files
jopperm and Naghasan
authored
[SYCL][Fusion] Support heterogeneous ND ranges on HIP (#11805)
This PR enables support for fusing kernels with heterogeneous ND ranges for the HIP target. It leverages the `TargetFusionInterface` established in #11421; however, because the AMDGCN backend lacks explicit intrinsics for querying the global/workgroup sizes, we had to extend the interface further so that *loads from the dispatch pointer with special offsets* can be remapped as builtins as well. --------- Signed-off-by: Julian Oppermann <julian.oppermann@codeplay.com> Co-authored-by: Victor Lomuller <victor@codeplay.com>
1 parent a325e62 commit e448888

File tree

13 files changed

+1670
-173
lines changed

13 files changed

+1670
-173
lines changed

libclc/amdgcn-amdhsa/libspirv/workitem/get_local_size.cl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,19 @@
2323
CONST_AS char * __clc_amdgcn_dispatch_ptr(void) __asm("llvm.amdgcn.dispatch.ptr");
2424
#endif
2525

26+
// Mimic `EmitAMDGPUWorkGroupSize` in `clang/lib/CodeGen/CGBuiltin.cpp`.
27+
2628
_CLC_DEF _CLC_OVERLOAD size_t __spirv_WorkgroupSize_x() {
27-
CONST_AS uint * ptr = (CONST_AS uint *) __dispatch_ptr();
28-
return ptr[1] & 0xffffu;
29+
CONST_AS ushort * ptr = (CONST_AS ushort *) __dispatch_ptr();
30+
return ptr[2];
2931
}
3032

3133
_CLC_DEF _CLC_OVERLOAD size_t __spirv_WorkgroupSize_y() {
32-
CONST_AS uint * ptr = (CONST_AS uint *) __dispatch_ptr();
33-
return ptr[1] >> 16;
34+
CONST_AS ushort * ptr = (CONST_AS ushort *) __dispatch_ptr();
35+
return ptr[3];
3436
}
3537

3638
_CLC_DEF _CLC_OVERLOAD size_t __spirv_WorkgroupSize_z() {
37-
CONST_AS uint * ptr = (CONST_AS uint *) __dispatch_ptr();
38-
return ptr[2] & 0xffffu;
39+
CONST_AS ushort * ptr = (CONST_AS ushort *) __dispatch_ptr();
40+
return ptr[4];
3941
}

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,6 @@ FusionResult KernelFusion::fuseKernels(
100100
"Fusion output target format not supported by this build");
101101
}
102102

103-
if (TargetFormat != BinaryFormat::SPIRV &&
104-
TargetFormat != BinaryFormat::PTX && IsHeterogeneousList) {
105-
return FusionResult{
106-
"Heterogeneous ND ranges not supported for this target"};
107-
}
108-
109103
bool CachingEnabled = ConfigHelper::get<option::JITEnableCaching>();
110104
CacheKeyT CacheKey{KernelsToFuse,
111105
Identities,

sycl-fusion/passes/kernel-fusion/Builtins.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,9 @@ static raw_ostream &operator<<(raw_ostream &Os, const NDRange &ND) {
3636
<< ND.getLocalSize();
3737
}
3838

39-
/// Will generate a unique function name so that it can be reused in further
40-
/// stages.
41-
static std::string getFunctionName(BuiltinKind K, const NDRange &SrcNDRange,
42-
const NDRange &FusedNDRange) {
39+
std::string Remapper::getFunctionName(BuiltinKind K, const NDRange &SrcNDRange,
40+
const NDRange &FusedNDRange,
41+
uint32_t Idx) {
4342
std::string Res;
4443
raw_string_ostream S{Res};
4544
S << "__" <<
@@ -63,6 +62,8 @@ static std::string getFunctionName(BuiltinKind K, const NDRange &SrcNDRange,
6362
llvm_unreachable("Unhandled kind");
6463
}()
6564
<< "_remapper_" << SrcNDRange << "_" << FusedNDRange;
65+
if (Idx != (uint32_t)-1)
66+
S << "_" << static_cast<char>('x' + Idx);
6667
return S.str();
6768
}
6869

@@ -339,13 +340,8 @@ jit_compiler::Remapper::remapBuiltins(Function *F, const NDRange &SrcNDRange,
339340
// If the builtin should not be remapped, return the original function.
340341
return F;
341342

342-
// Remap given builtin.
343-
const auto Name = getFunctionName(K, SrcNDRange, FusedNDRange);
344-
auto *M = F->getParent();
345-
assert(!M->getFunction(Name) && "Function name should be unique");
346-
347343
return Cached = TargetInfo.createRemapperFunction(
348-
*this, K, F->getName(), Name, M, SrcNDRange, FusedNDRange);
344+
*this, K, F, F->getParent(), SrcNDRange, FusedNDRange);
349345
}
350346
if (TargetInfo.isSafeToNotRemapBuiltin(F)) {
351347
// No need to remap.
@@ -375,20 +371,9 @@ jit_compiler::Remapper::remapBuiltins(Function *F, const NDRange &SrcNDRange,
375371

376372
// Set Cached to support recursive functions.
377373
Cached = Clone;
378-
for (auto &I : instructions(Clone)) {
379-
if (auto *Call = dyn_cast<CallBase>(&I)) {
380-
// Recursive call
381-
auto *OldF = Call->getCalledFunction();
382-
auto ErrOrNewF = remapBuiltins(OldF, SrcNDRange, FusedNDRange);
383-
if (auto Err = ErrOrNewF.takeError()) {
384-
return std::move(Err);
385-
}
386-
// Override called function.
387-
auto *NewF = *ErrOrNewF;
388-
Call->setCalledFunction(NewF);
389-
Call->setCallingConv(NewF->getCallingConv());
390-
Call->setAttributes(NewF->getAttributes());
391-
}
374+
if (auto Err = TargetInfo.scanForBuiltinsToRemap(Clone, *this, SrcNDRange,
375+
FusedNDRange)) {
376+
return Err;
392377
}
393378
return Clone;
394379
}

sycl-fusion/passes/kernel-fusion/Builtins.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ class Remapper {
4747
explicit Remapper(const llvm::TargetFusionInfo &TargetInfo)
4848
: TargetInfo(TargetInfo) {}
4949

50+
///
51+
/// Generate a unique function name for a remapper function.
52+
static std::string getFunctionName(BuiltinKind K, const NDRange &SrcNDRange,
53+
const NDRange &FusedNDRange,
54+
uint32_t Idx = -1);
55+
5056
///
5157
/// Recursively remap index space getters builtins.
5258
llvm::Expected<llvm::Function *> remapBuiltins(llvm::Function *F,

0 commit comments

Comments
 (0)