Skip to content

Commit 822cf9b

Browse files
fabiomestreEwanC
andauthored
[SYCL][Graph] Add common reference semantics (#16788)
Adds missing common reference semantic functionality such as operator==, operator!= and hash functions to all sycl graph related classes. --------- Co-authored-by: Ewan Crawford <ewan@codeplay.com>
1 parent a63f8b4 commit 822cf9b

File tree

6 files changed

+333
-5
lines changed

6 files changed

+333
-5
lines changed

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ class __SYCL_EXPORT node {
142142
/// Update the Range of this node if it is a kernel execution node
143143
template <int Dimensions> void update_range(range<Dimensions> executionRange);
144144

145+
/// Common Reference Semantics
146+
friend bool operator==(const node &LHS, const node &RHS) {
147+
return LHS.impl == RHS.impl;
148+
}
149+
friend bool operator!=(const node &LHS, const node &RHS) {
150+
return !operator==(LHS, RHS);
151+
}
152+
145153
private:
146154
node(const std::shared_ptr<detail::node_impl> &Impl) : impl(Impl) {}
147155

@@ -181,6 +189,16 @@ class __SYCL_EXPORT dynamic_command_group {
181189
size_t get_active_index() const;
182190
void set_active_index(size_t Index);
183191

192+
/// Common Reference Semantics
193+
friend bool operator==(const dynamic_command_group &LHS,
194+
const dynamic_command_group &RHS) {
195+
return LHS.impl == RHS.impl;
196+
}
197+
friend bool operator!=(const dynamic_command_group &LHS,
198+
const dynamic_command_group &RHS) {
199+
return !operator==(LHS, RHS);
200+
}
201+
184202
private:
185203
template <class Obj>
186204
friend const decltype(Obj::impl) &
@@ -307,6 +325,16 @@ class __SYCL_EXPORT modifiable_command_graph
307325
/// Get a list of all root nodes (nodes without dependencies) in this graph.
308326
std::vector<node> get_root_nodes() const;
309327

328+
/// Common Reference Semantics
329+
friend bool operator==(const modifiable_command_graph &LHS,
330+
const modifiable_command_graph &RHS) {
331+
return LHS.impl == RHS.impl;
332+
}
333+
friend bool operator!=(const modifiable_command_graph &LHS,
334+
const modifiable_command_graph &RHS) {
335+
return !operator==(LHS, RHS);
336+
}
337+
310338
protected:
311339
/// Constructor used internally by the runtime.
312340
/// @param Impl Detail implementation class to construct object with.
@@ -386,6 +414,16 @@ class __SYCL_EXPORT executable_command_graph
386414
/// @param Nodes The nodes to use for updating the graph.
387415
void update(const std::vector<node> &Nodes);
388416

417+
/// Common Reference Semantics
418+
friend bool operator==(const executable_command_graph &LHS,
419+
const executable_command_graph &RHS) {
420+
return LHS.impl == RHS.impl;
421+
}
422+
friend bool operator!=(const executable_command_graph &LHS,
423+
const executable_command_graph &RHS) {
424+
return !operator==(LHS, RHS);
425+
}
426+
389427
protected:
390428
/// Constructor used by internal runtime.
391429
/// @param Graph Detail implementation class to construct with.
@@ -452,6 +490,16 @@ class __SYCL_EXPORT dynamic_parameter_base {
452490
Graph,
453491
size_t ParamSize, const void *Data);
454492

493+
/// Common Reference Semantics
494+
friend bool operator==(const dynamic_parameter_base &LHS,
495+
const dynamic_parameter_base &RHS) {
496+
return LHS.impl == RHS.impl;
497+
}
498+
friend bool operator!=(const dynamic_parameter_base &LHS,
499+
const dynamic_parameter_base &RHS) {
500+
return !operator==(LHS, RHS);
501+
}
502+
455503
protected:
456504
void updateValue(const void *NewValue, size_t Size);
457505

@@ -512,3 +560,37 @@ command_graph(const context &SyclContext, const device &SyclDevice,
512560

513561
} // namespace _V1
514562
} // namespace sycl
563+
564+
namespace std {
565+
template <> struct __SYCL_EXPORT hash<sycl::ext::oneapi::experimental::node> {
566+
size_t operator()(const sycl::ext::oneapi::experimental::node &Node) const;
567+
};
568+
569+
template <>
570+
struct __SYCL_EXPORT
571+
hash<sycl::ext::oneapi::experimental::dynamic_command_group> {
572+
size_t operator()(const sycl::ext::oneapi::experimental::dynamic_command_group
573+
&DynamicCGH) const;
574+
};
575+
576+
template <sycl::ext::oneapi::experimental::graph_state State>
577+
struct __SYCL_EXPORT
578+
hash<sycl::ext::oneapi::experimental::command_graph<State>> {
579+
size_t operator()(const sycl::ext::oneapi::experimental::command_graph<State>
580+
&Graph) const {
581+
auto ID = sycl::detail::getSyclObjImpl(Graph)->getID();
582+
return std::hash<decltype(ID)>()(ID);
583+
}
584+
};
585+
586+
template <typename ValueT>
587+
struct __SYCL_EXPORT
588+
hash<sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>> {
589+
size_t
590+
operator()(const sycl::ext::oneapi::experimental::dynamic_parameter<ValueT>
591+
&DynamicParam) const {
592+
auto ID = sycl::detail::getSyclObjImpl(DynamicParam)->getID();
593+
return std::hash<decltype(ID)>()(ID);
594+
}
595+
};
596+
} // namespace std

sycl/source/detail/graph_impl.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,8 @@ graph_impl::graph_impl(const sycl::context &SyclContext,
324324
const sycl::device &SyclDevice,
325325
const sycl::property_list &PropList)
326326
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
327-
MEventsMap(), MInorderQueueMap() {
327+
MEventsMap(), MInorderQueueMap(),
328+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
328329
checkGraphPropertiesAndThrow(PropList);
329330
if (PropList.has_property<property::graph::no_cycle_check>()) {
330331
MSkipCycleChecks = true;
@@ -913,7 +914,8 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
913914
MExecutionEvents(),
914915
MIsUpdatable(PropList.has_property<property::graph::updatable>()),
915916
MEnableProfiling(
916-
PropList.has_property<property::graph::enable_profiling>()) {
917+
PropList.has_property<property::graph::enable_profiling>()),
918+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
917919
checkGraphPropertiesAndThrow(PropList);
918920
// If the graph has been marked as updatable then check if the backend
919921
// actually supports that. Devices supporting aspect::ext_oneapi_graph must
@@ -2026,7 +2028,8 @@ void dynamic_parameter_impl::updateCGAccessor(
20262028

20272029
dynamic_command_group_impl::dynamic_command_group_impl(
20282030
const command_graph<graph_state::modifiable> &Graph)
2029-
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0) {}
2031+
: MGraph{sycl::detail::getSyclObjImpl(Graph)}, MActiveCGF(0),
2032+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {}
20302033

20312034
void dynamic_command_group_impl::finalizeCGFList(
20322035
const std::vector<std::function<void(handler &)>> &CGFList) {
@@ -2150,3 +2153,17 @@ void dynamic_command_group::set_active_index(size_t Index) {
21502153
} // namespace ext
21512154
} // namespace _V1
21522155
} // namespace sycl
2156+
2157+
size_t std::hash<sycl::ext::oneapi::experimental::node>::operator()(
2158+
const sycl::ext::oneapi::experimental::node &Node) const {
2159+
auto ID = sycl::detail::getSyclObjImpl(Node)->getID();
2160+
return std::hash<decltype(ID)>()(ID);
2161+
}
2162+
2163+
size_t
2164+
std::hash<sycl::ext::oneapi::experimental::dynamic_command_group>::operator()(
2165+
const sycl::ext::oneapi::experimental::dynamic_command_group &DynamicCG)
2166+
const {
2167+
auto ID = sycl::detail::getSyclObjImpl(DynamicCG)->getID();
2168+
return std::hash<decltype(ID)>()(ID);
2169+
}

sycl/source/detail/graph_impl.hpp

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11201120
return MBarrierDependencyMap[Queue];
11211121
}
11221122

1123+
unsigned long long getID() const { return MID; }
1124+
11231125
private:
11241126
/// Iterate over the graph depth-first and run \p NodeFunc on each node.
11251127
/// @param NodeFunc A function which receives as input a node in the graph to
@@ -1198,6 +1200,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
11981200
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
11991201
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
12001202
MBarrierDependencyMap;
1203+
1204+
unsigned long long MID;
1205+
// Used for std::hash in order to create a unique hash for the instance.
1206+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
12011207
};
12021208

12031209
/// Class representing the implementation of command_graph<executable>.
@@ -1297,6 +1303,8 @@ class exec_graph_impl {
12971303

12981304
void updateImpl(std::shared_ptr<node_impl> NodeImpl);
12991305

1306+
unsigned long long getID() const { return MID; }
1307+
13001308
private:
13011309
/// Create a command-group for the node and add it to command-buffer by going
13021310
/// through the scheduler.
@@ -1408,21 +1416,27 @@ class exec_graph_impl {
14081416
// Stores a cache of node ids from modifiable graph nodes to the companion
14091417
// node(s) in this graph. Used for quick access when updating this graph.
14101418
std::multimap<node_impl::id_type, std::shared_ptr<node_impl>> MIDCache;
1419+
1420+
unsigned long long MID;
1421+
// Used for std::hash in order to create a unique hash for the instance.
1422+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
14111423
};
14121424

14131425
class dynamic_parameter_impl {
14141426
public:
14151427
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl,
14161428
size_t ParamSize, const void *Data)
1417-
: MGraph(GraphImpl), MValueStorage(ParamSize) {
1429+
: MGraph(GraphImpl), MValueStorage(ParamSize),
1430+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
14181431
std::memcpy(MValueStorage.data(), Data, ParamSize);
14191432
}
14201433

14211434
/// sycl_ext_oneapi_raw_kernel_arg constructor
14221435
/// Parameter size is taken from member of raw_kernel_arg object.
14231436
dynamic_parameter_impl(std::shared_ptr<graph_impl> GraphImpl, size_t,
14241437
raw_kernel_arg *Data)
1425-
: MGraph(GraphImpl) {
1438+
: MGraph(GraphImpl),
1439+
MID(NextAvailableID.fetch_add(1, std::memory_order_relaxed)) {
14261440
size_t RawArgSize = Data->MArgSize;
14271441
const void *RawArgData = Data->MArgData;
14281442
MValueStorage.reserve(RawArgSize);
@@ -1493,13 +1507,20 @@ class dynamic_parameter_impl {
14931507
int ArgIndex,
14941508
const sycl::detail::AccessorBaseHost *Acc);
14951509

1510+
unsigned long long getID() const { return MID; }
1511+
14961512
// Weak ptrs to node_impls which will be updated
14971513
std::vector<std::pair<std::weak_ptr<node_impl>, int>> MNodes;
14981514
// Dynamic command-groups which will be updated
14991515
std::vector<DynamicCGInfo> MDynCGs;
15001516

15011517
std::shared_ptr<graph_impl> MGraph;
15021518
std::vector<std::byte> MValueStorage;
1519+
1520+
private:
1521+
unsigned long long MID;
1522+
// Used for std::hash in order to create a unique hash for the instance.
1523+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
15031524
};
15041525

15051526
class dynamic_command_group_impl
@@ -1540,6 +1561,13 @@ class dynamic_command_group_impl
15401561

15411562
/// List of nodes using this dynamic command-group.
15421563
std::vector<std::weak_ptr<node_impl>> MNodes;
1564+
1565+
unsigned long long getID() const { return MID; }
1566+
1567+
private:
1568+
unsigned long long MID;
1569+
// Used for std::hash in order to create a unique hash for the instance.
1570+
inline static std::atomic<unsigned long long> NextAvailableID = 0;
15431571
};
15441572
} // namespace detail
15451573
} // namespace experimental

sycl/test/abi/sycl_symbols_windows.dump

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,12 @@
336336
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@$$QEAV0123456@@Z
337337
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@AEBV0123456@@Z
338338
??0dynamic_parameter_base@detail@experimental@oneapi@ext@_V1@sycl@@QEAA@V?$command_graph@$0A@@23456@_KPEBX@Z
339+
??4?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@AEBU01@@Z
340+
??4?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@$$QEAU01@@Z
341+
??R?$hash@Vdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@std@@QEBA_KAEBVdynamic_command_group@experimental@oneapi@ext@_V1@sycl@@@Z
342+
??R?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEBA_KAEBVnode@experimental@oneapi@ext@_V1@sycl@@@Z
343+
??4?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@$$QEAU01@@Z
344+
??4?$hash@Vnode@experimental@oneapi@ext@_V1@sycl@@@std@@QEAAAEAU01@AEBU01@@Z
339345
??0event@_V1@sycl@@AEAA@V?$shared_ptr@Vevent_impl@detail@_V1@sycl@@@std@@@Z
340346
??0event@_V1@sycl@@QEAA@$$QEAV012@@Z
341347
??0event@_V1@sycl@@QEAA@AEBV012@@Z

sycl/unittests/Extensions/CommandGraph/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ set(CMAKE_CXX_EXTENSIONS OFF)
33
add_sycl_unittest(CommandGraphExtensionTests OBJECT
44
Barrier.cpp
55
CommandGraph.cpp
6+
CommonReferenceSemantics.cpp
67
Exceptions.cpp
78
InOrderQueue.cpp
89
MultiThreaded.cpp

0 commit comments

Comments
 (0)