Skip to content

Commit 588d728

Browse files
authored
[SYCL][Fusion] Make target architecture part of JIT cache key (intel#12017)
JIT cache avoids running fusion JIT for previously found fused kernel sequences with a given set of parameters. Add target architecture to the set of parameters to avoid returning a previously fused kernel for a different target. Fused kernels for SPIR-V targets can be reused regardless of different architectures. --------- Signed-off-by: Victor Perez <victor.perez@codeplay.com>
1 parent 550b1b9 commit 588d728

File tree

9 files changed

+153
-10
lines changed

9 files changed

+153
-10
lines changed

sycl-fusion/common/include/Kernel.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,38 @@ enum class ParameterKind : uint32_t {
6161
/// Different binary formats supported as input to the JIT compiler.
6262
enum class BinaryFormat : uint32_t { INVALID, LLVM, SPIRV, PTX, AMDGCN };
6363

64+
/// Unique ID for each supported architecture in the SYCL implementation.
65+
///
66+
/// Values of this type will only be used in the kernel fusion non-persistent
67+
/// JIT. There is no guarantee for backwards compatibility, so this should not
68+
/// be used in persistent caches.
69+
using DeviceArchitecture = unsigned;
70+
71+
class TargetInfo {
72+
public:
73+
static constexpr TargetInfo get(BinaryFormat Format,
74+
DeviceArchitecture Arch) {
75+
if (Format == BinaryFormat::SPIRV) {
76+
/// As an exception, SPIR-V targets have a single common ID (-1), as fused
77+
/// kernels will be reused across SPIR-V devices.
78+
return {Format, DeviceArchitecture(-1)};
79+
}
80+
return {Format, Arch};
81+
}
82+
83+
TargetInfo() = default;
84+
85+
constexpr BinaryFormat getFormat() const { return Format; }
86+
constexpr DeviceArchitecture getArch() const { return Arch; }
87+
88+
private:
89+
constexpr TargetInfo(BinaryFormat Format, DeviceArchitecture Arch)
90+
: Format(Format), Arch(Arch) {}
91+
92+
BinaryFormat Format;
93+
DeviceArchitecture Arch;
94+
};
95+
6496
/// Information about a device intermediate representation module (e.g., SPIR-V,
6597
/// LLVM IR) from DPC++.
6698
struct SYCLKernelBinaryInfo {

sycl-fusion/jit-compiler/include/JITContext.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@ class LLVMContext;
2929
namespace jit_compiler {
3030

3131
using CacheKeyT =
32-
std::tuple<std::vector<std::string>, ParamIdentList, BarrierFlags,
33-
std::vector<ParameterInternalization>, std::vector<JITConstant>,
32+
std::tuple<DeviceArchitecture, std::vector<std::string>, ParamIdentList,
33+
BarrierFlags, std::vector<ParameterInternalization>,
34+
std::vector<JITConstant>,
3435
// This field of the cache is optional because, if all of the
3536
// ranges are equal, we will perform no remapping, so that fused
3637
// kernels can be reused with different lists of equal nd-ranges.

sycl-fusion/jit-compiler/include/Options.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
namespace jit_compiler {
1818

19-
enum OptionID { VerboseOutput, EnableCaching, TargetFormat };
19+
enum OptionID { VerboseOutput, EnableCaching, TargetDeviceInfo };
2020

2121
class OptionPtrBase {};
2222

@@ -80,8 +80,8 @@ struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {};
8080

8181
struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {};
8282

83-
struct JITTargetFormat
84-
: public OptionBase<OptionID::TargetFormat, BinaryFormat> {};
83+
struct JITTargetInfo
84+
: public OptionBase<OptionID::TargetDeviceInfo, TargetInfo> {};
8585

8686
} // namespace option
8787
} // namespace jit_compiler

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,18 @@ FusionResult KernelFusion::fuseKernels(
9494

9595
bool IsHeterogeneousList = jit_compiler::isHeterogeneousList(NDRanges);
9696

97-
BinaryFormat TargetFormat = ConfigHelper::get<option::JITTargetFormat>();
97+
TargetInfo TargetInfo = ConfigHelper::get<option::JITTargetInfo>();
98+
BinaryFormat TargetFormat = TargetInfo.getFormat();
99+
DeviceArchitecture TargetArch = TargetInfo.getArch();
98100

99101
if (!isTargetFormatSupported(TargetFormat)) {
100102
return FusionResult(
101103
"Fusion output target format not supported by this build");
102104
}
103105

104106
bool CachingEnabled = ConfigHelper::get<option::JITEnableCaching>();
105-
CacheKeyT CacheKey{KernelsToFuse,
107+
CacheKeyT CacheKey{TargetArch,
108+
KernelsToFuse,
106109
Identities,
107110
BarriersFlags,
108111
Internalization,

sycl/source/detail/device_impl.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,13 @@ class device_impl {
257257
std::string
258258
get_device_info_string(sycl::detail::pi::PiDeviceInfo InfoCode) const;
259259

260+
/// Get device architecture
261+
ext::oneapi::experimental::architecture getDeviceArch() const;
262+
260263
private:
261264
explicit device_impl(pi_native_handle InteropDevice,
262265
sycl::detail::pi::PiDevice Device,
263266
PlatformImplPtr Platform, const PluginPtr &Plugin);
264-
ext::oneapi::experimental::architecture getDeviceArch() const;
265267
sycl::detail::pi::PiDevice MDevice = 0;
266268
sycl::detail::pi::PiDeviceType MType;
267269
sycl::detail::pi::PiDevice MRootDevice = nullptr;

sycl/source/detail/jit_compiler.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ ::jit_compiler::BinaryFormat getTargetFormat(QueueImplPtr &Queue) {
5656
}
5757
}
5858

59+
::jit_compiler::TargetInfo getTargetInfo(QueueImplPtr &Queue) {
60+
::jit_compiler::BinaryFormat Format = getTargetFormat(Queue);
61+
return ::jit_compiler::TargetInfo::get(
62+
Format, static_cast<::jit_compiler::DeviceArchitecture>(
63+
Queue->getDeviceImplPtr()->getDeviceArch()));
64+
}
65+
5966
std::pair<const RTDeviceBinaryImage *, sycl::detail::pi::PiProgram>
6067
retrieveKernelBinary(QueueImplPtr &Queue, CGExecKernel *KernelCG) {
6168
auto KernelName = KernelCG->getKernelName();
@@ -824,8 +831,9 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
824831
JITConfig.set<::jit_compiler::option::JITEnableCaching>(
825832
detail::SYCLConfig<detail::SYCL_ENABLE_FUSION_CACHING>::get());
826833

827-
::jit_compiler::BinaryFormat TargetFormat = getTargetFormat(Queue);
828-
JITConfig.set<::jit_compiler::option::JITTargetFormat>(TargetFormat);
834+
::jit_compiler::TargetInfo TargetInfo = getTargetInfo(Queue);
835+
::jit_compiler::BinaryFormat TargetFormat = TargetInfo.getFormat();
836+
JITConfig.set<::jit_compiler::option::JITTargetInfo>(TargetInfo);
829837

830838
auto FusionResult = ::jit_compiler::KernelFusion::fuseKernels(
831839
*MJITContext, std::move(JITConfig), InputKernelInfo, InputKernelNames,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// REQUIRES: fusion, gpu, (opencl || level_zero)
2+
// RUN: %{build} -O2 -o %t.out
3+
// RUN: env SYCL_RT_WARNING_LEVEL=1 %{run-unfiltered-devices} %t.out 2>&1 | FileCheck %s --implicit-check-not "WRONG a VALUE" --implicit-check-not "WRONG b VALUE"
4+
5+
// Test caching for JIT fused kernels when different SPIR-V devices are
6+
// involved.
7+
8+
#include "./jit_caching_multitarget_common.h"
9+
10+
// Initial invocation
11+
// CHECK: JIT DEBUG: Compiling new kernel, no suitable cached kernel found
12+
13+
// Identical invocation, should lead to JIT cache hit.
14+
// CHECK-NEXT: JIT DEBUG: Re-using cached JIT kernel
15+
// CHECK-NEXT: INFO: Re-using existing device binary for fused kernel
16+
17+
// Invocation with another SPIR-V device involved. Should lead to JIT cache hit.
18+
// CHECK-NEXT: JIT DEBUG: Re-using cached JIT kernel
19+
// CHECK-NEXT: INFO: Re-using existing device binary for fused kernel
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// REQUIRES: fusion, gpu, (hip || cuda)
2+
// RUN: %{build} -fsycl-embed-ir -O2 -o %t.out
3+
// RUN: env SYCL_RT_WARNING_LEVEL=1 %{run-unfiltered-devices} %t.out 2>&1 | FileCheck %s --implicit-check-not "WRONG a VALUE" --implicit-check-not "WRONG b VALUE"
4+
// XFAIL: *
5+
6+
// COM: This test is expected to fail on CI, as CUDA and HIP runners do not
7+
// provide a CPU backend.
8+
9+
// Test caching for JIT fused kernels when devices with different architectures
10+
// are involved.
11+
12+
#include "./jit_caching_multitarget_common.h"
13+
14+
// Initial invocation
15+
// CHECK: JIT DEBUG: Compiling new kernel, no suitable cached kernel found
16+
17+
// Identical invocation, should lead to JIT cache hit.
18+
// CHECK-NEXT: JIT DEBUG: Re-using cached JIT kernel
19+
// CHECK-NEXT: INFO: Re-using existing device binary for fused kernel
20+
21+
// Invocation with a device with a different architecture. Should trigger JIT
22+
// compilation.
23+
// CHECK-NEXT: JIT DEBUG: Compiling new kernel, no suitable cached kernel found
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Test caching for JIT fused kernels when different devices are involved.
2+
3+
#include <sycl/sycl.hpp>
4+
5+
using namespace sycl;
6+
7+
constexpr inline std::size_t size(1024);
8+
9+
class Kernel0;
10+
class Kernel1;
11+
12+
void performFusion(queue q, std::size_t *a, std::size_t *b) {
13+
{
14+
buffer<std::size_t> a_buf(a, size);
15+
buffer<std::size_t> b_buf(b, size);
16+
17+
ext::codeplay::experimental::fusion_wrapper fw{q};
18+
fw.start_fusion();
19+
q.submit([&](handler &cgh) {
20+
accessor a(a_buf, cgh, write_only, no_init);
21+
cgh.parallel_for<Kernel0>(size, [=](id<1> i) { a[i] = i; });
22+
});
23+
q.submit([&](handler &cgh) {
24+
accessor a(a_buf, cgh, read_only);
25+
accessor b(b_buf, cgh, write_only, no_init);
26+
cgh.parallel_for<Kernel1>(size, [=](id<1> i) { b[i] = a[i] * 2; });
27+
});
28+
fw.complete_fusion();
29+
}
30+
for (std::size_t i = 0; i < size; ++i) {
31+
assert(a[i] == i && "WRONG a VALUE");
32+
assert(b[i] == i * 2 && "WRONG b VALUE");
33+
}
34+
}
35+
36+
int main() {
37+
queue q{gpu_selector_v,
38+
ext::codeplay::experimental::property::queue::enable_fusion{}};
39+
queue q_cpu{cpu_selector_v,
40+
ext::codeplay::experimental::property::queue::enable_fusion{}};
41+
42+
std::vector<std::size_t> a(size);
43+
std::vector<std::size_t> b(size);
44+
45+
// Initial invocation
46+
performFusion(q, a.data(), b.data());
47+
48+
// Identical invocation, should lead to JIT cache hit.
49+
performFusion(q, a.data(), b.data());
50+
51+
// Invocation on CPU device.
52+
performFusion(q_cpu, a.data(), b.data());
53+
54+
return 0;
55+
}

0 commit comments

Comments
 (0)