Skip to content

Commit 28410c6

Browse files
committed
Implement experimental 2way prefetch function
1 parent 0c7e1cb commit 28410c6

File tree

10 files changed

+209
-12
lines changed

10 files changed

+209
-12
lines changed

sycl/cmake/modules/FetchUnifiedRuntime.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,12 @@ if(SYCL_UR_USE_FETCH_CONTENT)
116116
CACHE PATH "Path to external '${name}' adapter source dir" FORCE)
117117
endfunction()
118118

119-
set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git")
119+
set(UNIFIED_RUNTIME_REPO "https://github.com/ianayl/unified-runtime.git")
120120
# commit af7e275b509b41f54a66743ebf748dfb51668abf
121121
# Author: Maosu Zhao <maosu.zhao@intel.com>
122122
# Date: Thu Oct 17 16:31:21 2024 +0800
123123
# [DeviceSanitizer] Refactor the code to manage shadow memory (#2127)
124-
set(UNIFIED_RUNTIME_TAG af7e275b509b41f54a66743ebf748dfb51668abf)
124+
set(UNIFIED_RUNTIME_TAG 644eef9f8dfcc3931fd2883f24eace60ca98b3ed)
125125

126126
set(UMF_BUILD_EXAMPLES OFF CACHE INTERNAL "EXAMPLES")
127127
# Due to the use of dependentloadflag and no installer for UMF and hwloc we need
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//==-------- usm_prefetch_exp.hpp --- SYCL USM prefetch extensions ---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
namespace sycl {
12+
inline namespace _V1 {
13+
14+
namespace ext::oneapi::experimental {
15+
16+
/// @brief Indicates USM memory migration direction: either from host to device, or device to host.
17+
enum class migration_direction {
18+
HOST_TO_DEVICE, /// Move data from host USM to device USM
19+
DEVICE_TO_HOST /// Move data from device USM to host USM
20+
};
21+
22+
} // namespace ext::oneapi::experimental
23+
} // namespace _V1
24+
} // namespace sycl

sycl/include/sycl/handler.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <sycl/ext/oneapi/experimental/cluster_group_prop.hpp>
3636
#include <sycl/ext/oneapi/experimental/graph.hpp>
3737
#include <sycl/ext/oneapi/experimental/raw_kernel_arg.hpp>
38+
#include <sycl/ext/oneapi/experimental/USM/prefetch_exp.hpp>
3839
#include <sycl/ext/oneapi/experimental/use_root_sync_prop.hpp>
3940
#include <sycl/ext/oneapi/experimental/virtual_functions.hpp>
4041
#include <sycl/ext/oneapi/kernel_properties/properties.hpp>
@@ -2824,6 +2825,17 @@ class __SYCL_EXPORT handler {
28242825
/// \param Count is a number of bytes to be prefetched.
28252826
void prefetch(const void *Ptr, size_t Count);
28262827

2828+
/// Experimental implementation of prefetch supporting bidirectional USM data
2829+
/// migration: Provides hints to the runtime library that data should be made
2830+
/// available on a device earlier than Unified Shared Memory would normally
2831+
/// require it to be available.
2832+
///
2833+
/// \param CGH is the handler to be used for prefetching.
2834+
/// \param Ptr is a USM pointer to the memory to be prefetched to the destination.
2835+
/// \param Count is a number of bytes to be prefetched.
2836+
/// \param Direction indicates the direction to prefetch data to/from.
2837+
void ext_oneapi_prefetch_exp(const void* Ptr, size_t Count, ext::oneapi::experimental::migration_direction Direction = ext::oneapi::experimental::migration_direction::HOST_TO_DEVICE);
2838+
28272839
/// Provides additional information to the underlying runtime about how
28282840
/// different allocations are used.
28292841
///
@@ -3253,6 +3265,9 @@ class __SYCL_EXPORT handler {
32533265
detail::code_location MCodeLoc = {};
32543266
bool MIsFinalized = false;
32553267
event MLastEvent;
3268+
/// Enum to indicate USM data migration direction
3269+
ext::oneapi::experimental::migration_direction MDirection = ext::oneapi::experimental::migration_direction::HOST_TO_DEVICE;
3270+
32563271

32573272
// Make queue_impl class friend to be able to call finalize method.
32583273
friend class detail::queue_impl;

sycl/include/sycl/queue.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <sycl/ext/oneapi/device_global/device_global.hpp> // for device_global
3232
#include <sycl/ext/oneapi/device_global/properties.hpp> // for device_image_s...
3333
#include <sycl/ext/oneapi/experimental/graph.hpp> // for command_graph...
34+
#include <sycl/ext/oneapi/experimental/USM/prefetch_exp.hpp> // for migration...
3435
#include <sycl/ext/oneapi/properties/properties.hpp> // for empty_properti...
3536
#include <sycl/handler.hpp> // for handler, isDev...
3637
#include <sycl/id.hpp> // for id
@@ -745,6 +746,24 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
745746
TlsCodeLocCapture.query());
746747
}
747748

749+
/// Experimental implementation of prefetch supporting bidirectional USM data
750+
/// migration: Provides hints to the runtime library that data should be made
751+
/// available on a device earlier than Unified Shared Memory would normally
752+
/// require it to be available.
753+
///
754+
/// \param Ptr is a USM pointer to the memory to be prefetched to the device.
755+
/// \param Count is a number of bytes to be prefetched.
756+
/// \param Direction indicates the direction to prefetch data to/from.
757+
/// \return an event representing prefetch operation.
758+
event ext_oneapi_prefetch_exp(
759+
const void *Ptr, size_t Count,
760+
ext::oneapi::experimental::migration_direction Direction = ext::oneapi::experimental::migration_direction::HOST_TO_DEVICE,
761+
const detail::code_location &CodeLoc = detail::code_location::current()) {
762+
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
763+
return submit([=](handler &CGH) { CGH.ext_oneapi_prefetch_exp(Ptr, Count, Direction); },
764+
TlsCodeLocCapture.query());
765+
}
766+
748767
/// Provides hints to the runtime library that data should be made available
749768
/// on a device earlier than Unified Shared Memory would normally require it
750769
/// to be available.
@@ -765,6 +784,29 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
765784
TlsCodeLocCapture.query());
766785
}
767786

787+
/// Experimental implementation of prefetch supporting bidirectional USM data
788+
/// migration: Provides hints to the runtime library that data should be made
789+
/// available on a device earlier than Unified Shared Memory would normally
790+
/// require it to be available.
791+
///
792+
/// \param Ptr is a USM pointer to the memory to be prefetched to the device.
793+
/// \param Count is a number of bytes to be prefetched.
794+
/// \param DepEvent is an event that specifies the kernel dependencies.
795+
/// \param Direction indicates the direction to prefetch data to/from.
796+
/// \return an event representing prefetch operation.
797+
event ext_oneapi_prefetch_exp(
798+
const void *Ptr, size_t Count, event DepEvent,
799+
ext::oneapi::experimental::migration_direction Direction = ext::oneapi::experimental::migration_direction::HOST_TO_DEVICE,
800+
const detail::code_location &CodeLoc = detail::code_location::current()) {
801+
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
802+
return submit(
803+
[=](handler &CGH) {
804+
CGH.depends_on(DepEvent);
805+
CGH.ext_oneapi_prefetch_exp(Ptr, Count, Direction);
806+
},
807+
TlsCodeLocCapture.query());
808+
}
809+
768810
/// Provides hints to the runtime library that data should be made available
769811
/// on a device earlier than Unified Shared Memory would normally require it
770812
/// to be available.
@@ -786,6 +828,30 @@ class __SYCL_EXPORT queue : public detail::OwnerLessBase<queue> {
786828
TlsCodeLocCapture.query());
787829
}
788830

831+
/// Experimental implementation of prefetch supporting bidirectional USM data
832+
/// migration: Provides hints to the runtime library that data should be made
833+
/// available on a device earlier than Unified Shared Memory would normally
834+
/// require it to be available.
835+
///
836+
/// \param Ptr is a USM pointer to the memory to be prefetched to the device.
837+
/// \param Count is a number of bytes to be prefetched.
838+
/// \param DepEvents is a vector of events that specifies the kernel
839+
/// dependencies.
840+
/// \param Direction indicates the direction to prefetch data to/from.
841+
/// \return an event representing prefetch operation.
842+
event ext_oneapi_prefetch_exp(
843+
const void *Ptr, size_t Count, const std::vector<event> &DepEvents,
844+
ext::oneapi::experimental::migration_direction Direction = ext::oneapi::experimental::migration_direction::HOST_TO_DEVICE,
845+
const detail::code_location &CodeLoc = detail::code_location::current()) {
846+
detail::tls_code_loc_t TlsCodeLocCapture(CodeLoc);
847+
return submit(
848+
[=](handler &CGH) {
849+
CGH.depends_on(DepEvents);
850+
CGH.ext_oneapi_prefetch_exp(Ptr, Count, Direction);
851+
},
852+
TlsCodeLocCapture.query());
853+
}
854+
789855
/// Copies data from one 2D memory region to another, both pointed by
790856
/// USM pointers.
791857
/// No operations is done if \param Width or \param Height is zero. An

sycl/source/detail/cg.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <sycl/exception_list.hpp> // for queue_impl
1818
#include <sycl/kernel.hpp> // for kernel_impl
1919
#include <sycl/kernel_bundle.hpp> // for kernel_bundle_impl
20+
#include <sycl/ext/oneapi/experimental/USM/prefetch_exp.hpp> // for migration_direction
2021

2122
#include <assert.h> // for assert
2223
#include <memory> // for shared_ptr, unique_ptr
@@ -390,14 +391,16 @@ class CGFillUSM : public CG {
390391
class CGPrefetchUSM : public CG {
391392
void *MDst;
392393
size_t MLength;
394+
ext::oneapi::experimental::migration_direction MDirection;
393395

394396
public:
395-
CGPrefetchUSM(void *DstPtr, size_t Length, CG::StorageInitHelper CGData,
397+
CGPrefetchUSM(void *DstPtr, size_t Length, CG::StorageInitHelper CGData, ext::oneapi::experimental::migration_direction Direction,
396398
detail::code_location loc = {})
397399
: CG(CGType::PrefetchUSM, std::move(CGData), std::move(loc)),
398-
MDst(DstPtr), MLength(Length) {}
400+
MDst(DstPtr), MLength(Length), MDirection(Direction) {}
399401
void *getDst() { return MDst; }
400402
size_t getLength() { return MLength; }
403+
ext::oneapi::experimental::migration_direction getDirection() { return MDirection; }
401404
};
402405

403406
/// "Advise USM" command group class.

sycl/source/detail/memory_manager.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -964,16 +964,19 @@ void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
964964
DepEvents.size(), DepEvents.data(), OutEvent);
965965
}
966966

967-
void MemoryManager::prefetch_usm(void *Mem, QueueImplPtr Queue, size_t Length,
967+
void MemoryManager::prefetch_usm(void *Mem, QueueImplPtr Queue, size_t Length, sycl::ext::oneapi::experimental::migration_direction Direction,
968968
std::vector<ur_event_handle_t> DepEvents,
969969
ur_event_handle_t *OutEvent,
970970
const detail::EventImplPtr &OutEventImpl) {
971971
assert(Queue && "USM prefetch must be called with a valid device queue");
972972
const AdapterPtr &Adapter = Queue->getAdapter();
973+
ur_usm_migration_flags_t migration_flag = UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE;
974+
if (Direction == sycl::ext::oneapi::experimental::migration_direction::DEVICE_TO_HOST)
975+
migration_flag = UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST;
973976
if (OutEventImpl != nullptr)
974977
OutEventImpl->setHostEnqueueTime();
975978
Adapter->call<UrApiKind::urEnqueueUSMPrefetch>(Queue->getHandleRef(), Mem,
976-
Length, 0, DepEvents.size(),
979+
Length, migration_flag, DepEvents.size(),
977980
DepEvents.data(), OutEvent);
978981
}
979982

@@ -1610,12 +1613,15 @@ void MemoryManager::ext_oneapi_fill_cmd_buffer(
16101613

16111614
void MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer(
16121615
sycl::detail::ContextImplPtr Context,
1613-
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length,
1616+
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length, sycl::ext::oneapi::experimental::migration_direction Direction,
16141617
std::vector<ur_exp_command_buffer_sync_point_t> Deps,
16151618
ur_exp_command_buffer_sync_point_t *OutSyncPoint) {
16161619
const AdapterPtr &Adapter = Context->getAdapter();
1620+
ur_usm_migration_flags_t migration_flag = UR_USM_MIGRATION_FLAG_HOST_TO_DEVICE;
1621+
if (Direction == sycl::ext::oneapi::experimental::migration_direction::DEVICE_TO_HOST)
1622+
migration_flag = UR_USM_MIGRATION_FLAG_DEVICE_TO_HOST;
16171623
Adapter->call<UrApiKind::urCommandBufferAppendUSMPrefetchExp>(
1618-
CommandBuffer, Mem, Length, ur_usm_migration_flags_t(0), Deps.size(),
1624+
CommandBuffer, Mem, Length, migration_flag, Deps.size(),
16191625
Deps.data(), 0, nullptr, OutSyncPoint, nullptr, nullptr);
16201626
}
16211627

sycl/source/detail/memory_manager.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <sycl/id.hpp>
1515
#include <sycl/property_list.hpp>
1616
#include <sycl/range.hpp>
17+
#include <sycl/ext/oneapi/experimental/USM/prefetch_exp.hpp>
1718

1819
#include <ur_api.h>
1920

@@ -149,7 +150,7 @@ class MemoryManager {
149150
ur_event_handle_t *OutEvent,
150151
const detail::EventImplPtr &OutEventImpl);
151152

152-
static void prefetch_usm(void *Ptr, QueueImplPtr Queue, size_t Len,
153+
static void prefetch_usm(void *Ptr, QueueImplPtr Queue, size_t Len, sycl::ext::oneapi::experimental::migration_direction Direction,
153154
std::vector<ur_event_handle_t> DepEvents,
154155
ur_event_handle_t *OutEvent,
155156
const detail::EventImplPtr &OutEventImpl);
@@ -250,6 +251,7 @@ class MemoryManager {
250251
static void ext_oneapi_prefetch_usm_cmd_buffer(
251252
sycl::detail::ContextImplPtr Context,
252253
ur_exp_command_buffer_handle_t CommandBuffer, void *Mem, size_t Length,
254+
sycl::ext::oneapi::experimental::migration_direction Direction,
253255
std::vector<ur_exp_command_buffer_sync_point_t> Deps,
254256
ur_exp_command_buffer_sync_point_t *OutSyncPoint);
255257

sycl/source/detail/scheduler/commands.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2867,7 +2867,7 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
28672867
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
28682868
MemoryManager::ext_oneapi_prefetch_usm_cmd_buffer(
28692869
MQueue->getContextImplPtr(), MCommandBuffer, Prefetch->getDst(),
2870-
Prefetch->getLength(), std::move(MSyncPointDeps), &OutSyncPoint);
2870+
Prefetch->getLength(), Prefetch->getDirection(), std::move(MSyncPointDeps), &OutSyncPoint);
28712871
MEvent->setSyncPoint(OutSyncPoint);
28722872
return UR_RESULT_SUCCESS;
28732873
}
@@ -3044,7 +3044,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
30443044
case CGType::PrefetchUSM: {
30453045
CGPrefetchUSM *Prefetch = (CGPrefetchUSM *)MCommandGroup.get();
30463046
MemoryManager::prefetch_usm(Prefetch->getDst(), MQueue,
3047-
Prefetch->getLength(), std::move(RawEvents),
3047+
Prefetch->getLength(), Prefetch->getDirection(), std::move(RawEvents),
30483048
Event, MEvent);
30493049
if (Event)
30503050
MEvent->setHandle(*Event);

sycl/source/handler.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ event handler::finalize() {
377377
break;
378378
case detail::CGType::PrefetchUSM:
379379
CommandGroup.reset(new detail::CGPrefetchUSM(
380-
MDstPtr, MLength, std::move(impl->CGData), MCodeLoc));
380+
MDstPtr, MLength, std::move(impl->CGData), MDirection, MCodeLoc));
381381
break;
382382
case detail::CGType::AdviseUSM:
383383
CommandGroup.reset(new detail::CGAdviseUSM(MDstPtr, MLength, impl->MAdvice,
@@ -970,6 +970,14 @@ void handler::prefetch(const void *Ptr, size_t Count) {
970970
setType(detail::CGType::PrefetchUSM);
971971
}
972972

973+
void handler::ext_oneapi_prefetch_exp(const void* ptr, size_t Count, ext::oneapi::experimental::migration_direction Direction) {
974+
throwIfActionIsCreated();
975+
MDstPtr = const_cast<void *>(ptr);
976+
MLength = Count;
977+
MDirection = Direction;
978+
setType(sycl::detail::CGType::PrefetchUSM);
979+
}
980+
973981
void handler::mem_advise(const void *Ptr, size_t Count, int Advice) {
974982
throwIfActionIsCreated();
975983
MDstPtr = const_cast<void *>(Ptr);

sycl/test-e2e/USM/prefetch_exp.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//==---- prefetch_exp.cpp - Experimental 2-way USM prefetch test ------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// RUN: %{build} -o %t1.out
10+
// RUN: %{run} %t1.out
11+
12+
#include <sycl/detail/core.hpp>
13+
#include <sycl/usm.hpp>
14+
#include <sycl/ext/oneapi/experimental/USM/prefetch_exp.hpp>
15+
16+
using namespace sycl;
17+
18+
static constexpr int count = 100;
19+
20+
int main() {
21+
queue q([](exception_list el) {
22+
for (auto &e : el)
23+
throw e;
24+
});
25+
if (q.get_device().get_info<info::device::usm_shared_allocations>()) {
26+
float *src = (float *)malloc_shared(sizeof(float) * count, q.get_device(),
27+
q.get_context());
28+
float *dest = (float *)malloc_shared(sizeof(float) * count, q.get_device(),
29+
q.get_context());
30+
for (int i = 0; i < count; i++)
31+
src[i] = i;
32+
33+
// Test host to device prefetch_exp(handler &CGH, ..)
34+
{
35+
event init_prefetch = q.submit(
36+
[&](handler &cgh) { cgh.ext_oneapi_prefetch_exp(src, sizeof(float) * count); });
37+
38+
q.submit([&](handler &cgh) {
39+
cgh.depends_on(init_prefetch);
40+
cgh.single_task<class double_dest>([=]() {
41+
for (int i = 0; i < count; i++)
42+
dest[i] = 2 * src[i];
43+
});
44+
}
45+
q.wait_and_throw();
46+
47+
for (int i = 0; i < count; i++) {
48+
assert(dest[i] == i * 2);
49+
}
50+
}
51+
52+
// Test queue::prefetch
53+
{
54+
event init_prefetch = q.ext_oneapi_prefetch_exp(src, sizeof(float) * count);
55+
56+
q.submit([&](handler &cgh) {
57+
cgh.depends_on(init_prefetch);
58+
cgh.single_task<class double_dest3>([=]() {
59+
for (int i = 0; i < count; i++)
60+
dest[i] = 3 * src[i];
61+
});
62+
});
63+
q.wait_and_throw();
64+
65+
for (int i = 0; i < count; i++) {
66+
assert(dest[i] == i * 3);
67+
}
68+
}
69+
free(src, q);
70+
free(dest, q);
71+
}
72+
return 0;
73+
}

0 commit comments

Comments
 (0)