Skip to content

Commit 5c73a20

Browse files
authored
[SYCL][Graph] Implement dynamic command-groups (#15700)
Implement Dynamic Command-Group feature specified in PR [[SYCL][Graph] Add specification for kernel binary updates](#14896). This feature enables updating `ur_kernel_handle_t` objects in graph nodes between executions as well as parameters and execution range of nodes. Points to note in this change: * The functionality is currently supported on CUDA & HIP which are used for testing in the new E2E tests. Level Zero support will follow shortly, resulting in the removal of the `XFAIL` labels with tracker number from the E2E tests. * The code for adding nodes to a graph has been refactored to split out verification of edges, and marking memory objects used in a node, as separate helper functions. This allows the path for adding a command-group node to run these functions over each CG in the list before creating the node itself. * The `dynamic_parameter_impl` code has also been refactored so the code is shared for updating a dynamic parameter used in both a regular kernel node and a dynamic command-group node. * There is no longer a need for the `handler::setNDRangeUsed()` API now that graph kernel nodes can update between kernels using `sycl::nd_range` and `sycl::range`. The functionality in this method has been turned into a no-op, however removing the method is an ABI breaking change, so it remains guarded by the `__INTEL_PREVIEW_BREAKING_CHANGES` macro. See the addition to the design doc for further details on the implementation.
1 parent 8a5133f commit 5c73a20

33 files changed

+2311
-265
lines changed

sycl/doc/design/CommandGraph.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,29 @@ requirements for these new accessors to correctly trigger allocations before
282282
updating. This is similar to how individual graph commands are enqueued when
283283
accessors are used in a graph node.
284284

285+
### Dynamic Command-Group
286+
287+
To implement the `dynamic_command_group` class for updating the command-groups (CG)
288+
associated with nodes, the CG member of the node implementation class changes
289+
from a `std::unique_ptr` to a `std::shared_ptr` so that multiple nodes and the
290+
`dynamic_command_group_impl` object can share the same CG object. This avoids
291+
the overhead of having to allocate and free copies of the CG when a new active
292+
CG is selected.
293+
294+
The `dynamic_command_group_impl` class contains a list of weak pointers to the
295+
nodes which have been created with it, so that when a new active CG is selected
296+
it can propagate the change to those nodes. The `dynamic_parameter_impl` class
297+
also contains a list of weak pointers, but to the `dynamic_command_group_impl`
298+
instances of any dynamic command-groups where they are used. This allows
299+
updating the dynamic parameter to propagate to dynamic command-group nodes.
300+
301+
The `sycl::detail::CGExecKernel` class has been extended so that if the
302+
object was created from an element in the dynamic command-group list, the class
303+
stores a vector of weak pointers to the other alternative command-groups created
304+
from the same dynamic command-group object. This allows the SYCL runtime to
305+
access the list of alternative kernels when calling the UR API to append a
306+
kernel command to a command-buffer.
307+
285308
## Optimizations
286309
### Interactions with Profiling
287310

sycl/include/sycl/ext/oneapi/experimental/graph.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class node_impl;
9696
class graph_impl;
9797
class exec_graph_impl;
9898
class dynamic_parameter_impl;
99+
class dynamic_command_group_impl;
99100
} // namespace detail
100101

101102
enum class node_type {
@@ -216,6 +217,23 @@ class depends_on_all_leaves : public ::sycl::detail::DataLessProperty<
216217
} // namespace node
217218
} // namespace property
218219

220+
class __SYCL_EXPORT dynamic_command_group {
221+
public:
222+
dynamic_command_group(
223+
const command_graph<graph_state::modifiable> &Graph,
224+
const std::vector<std::function<void(handler &)>> &CGFList);
225+
226+
size_t get_active_cgf() const;
227+
void set_active_cgf(size_t Index);
228+
229+
private:
230+
template <class Obj>
231+
friend const decltype(Obj::impl) &
232+
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
233+
234+
std::shared_ptr<detail::dynamic_command_group_impl> impl;
235+
};
236+
219237
namespace detail {
220238
// Templateless modifiable command-graph base class.
221239
class __SYCL_EXPORT modifiable_command_graph {
@@ -337,6 +355,12 @@ class __SYCL_EXPORT modifiable_command_graph {
337355
modifiable_command_graph(const std::shared_ptr<detail::graph_impl> &Impl)
338356
: impl(Impl) {}
339357

358+
/// Template-less implementation of add() for dynamic command-group nodes.
359+
/// @param DynCGF Dynamic Command-group function object to add.
360+
/// @param Dep List of predecessor nodes.
361+
/// @return Node added to the graph.
362+
node addImpl(dynamic_command_group &DynCGF, const std::vector<node> &Dep);
363+
340364
/// Template-less implementation of add() for CGF nodes.
341365
/// @param CGF Command-group function to add.
342366
/// @param Dep List of predecessor nodes.

sycl/include/sycl/handler.hpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,6 @@ class __SYCL_EXPORT handler {
11751175
StoreLambda<KName, decltype(Wrapper), Dims, TransformedArgType>(
11761176
std::move(Wrapper));
11771177
setType(detail::CGType::Kernel);
1178-
setNDRangeUsed(false);
11791178
#endif
11801179
} else
11811180
#endif // !__SYCL_DISABLE_PARALLEL_FOR_RANGE_ROUNDING__ &&
@@ -1198,7 +1197,6 @@ class __SYCL_EXPORT handler {
11981197
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
11991198
std::move(KernelFunc));
12001199
setType(detail::CGType::Kernel);
1201-
setNDRangeUsed(false);
12021200
#endif
12031201
#else
12041202
(void)KernelFunc;
@@ -1249,7 +1247,6 @@ class __SYCL_EXPORT handler {
12491247
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
12501248
std::move(KernelFunc));
12511249
setType(detail::CGType::Kernel);
1252-
setNDRangeUsed(true);
12531250
#endif
12541251
}
12551252

@@ -1272,7 +1269,6 @@ class __SYCL_EXPORT handler {
12721269
setNDRangeDescriptor(std::move(NumWorkItems));
12731270
processLaunchProperties<PropertiesT>(Props);
12741271
setType(detail::CGType::Kernel);
1275-
setNDRangeUsed(false);
12761272
extractArgsAndReqs();
12771273
MKernelName = getKernelName();
12781274
#endif
@@ -1298,7 +1294,6 @@ class __SYCL_EXPORT handler {
12981294
setNDRangeDescriptor(std::move(NDRange));
12991295
processLaunchProperties(Props);
13001296
setType(detail::CGType::Kernel);
1301-
setNDRangeUsed(true);
13021297
extractArgsAndReqs();
13031298
MKernelName = getKernelName();
13041299
#endif
@@ -1339,7 +1334,6 @@ class __SYCL_EXPORT handler {
13391334
setNDRangeDescriptor(NumWorkGroups, /*SetNumWorkGroups=*/true);
13401335
StoreLambda<NameT, KernelType, Dims, LambdaArgType>(std::move(KernelFunc));
13411336
setType(detail::CGType::Kernel);
1342-
setNDRangeUsed(false);
13431337
#endif // __SYCL_DEVICE_ONLY__
13441338
}
13451339

@@ -1971,7 +1965,6 @@ class __SYCL_EXPORT handler {
19711965
StoreLambda<NameT, KernelType, Dims, TransformedArgType>(
19721966
std::move(KernelFunc));
19731967
setType(detail::CGType::Kernel);
1974-
setNDRangeUsed(false);
19751968
#endif
19761969
}
19771970

@@ -2069,7 +2062,6 @@ class __SYCL_EXPORT handler {
20692062
detail::checkValueRange<Dims>(NumWorkItems, WorkItemOffset);
20702063
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
20712064
setType(detail::CGType::Kernel);
2072-
setNDRangeUsed(false);
20732065
extractArgsAndReqs();
20742066
MKernelName = getKernelName();
20752067
#endif
@@ -2148,7 +2140,6 @@ class __SYCL_EXPORT handler {
21482140
setNDRangeDescriptor(std::move(NumWorkItems));
21492141
MKernel = detail::getSyclObjImpl(std::move(Kernel));
21502142
setType(detail::CGType::Kernel);
2151-
setNDRangeUsed(false);
21522143
if (!lambdaAndKernelHaveEqualName<NameT>()) {
21532144
extractArgsAndReqs();
21542145
MKernelName = getKernelName();
@@ -2189,7 +2180,6 @@ class __SYCL_EXPORT handler {
21892180
setNDRangeDescriptor(std::move(NumWorkItems), std::move(WorkItemOffset));
21902181
MKernel = detail::getSyclObjImpl(std::move(Kernel));
21912182
setType(detail::CGType::Kernel);
2192-
setNDRangeUsed(false);
21932183
if (!lambdaAndKernelHaveEqualName<NameT>()) {
21942184
extractArgsAndReqs();
21952185
MKernelName = getKernelName();
@@ -2229,7 +2219,6 @@ class __SYCL_EXPORT handler {
22292219
setNDRangeDescriptor(std::move(NDRange));
22302220
MKernel = detail::getSyclObjImpl(std::move(Kernel));
22312221
setType(detail::CGType::Kernel);
2232-
setNDRangeUsed(true);
22332222
if (!lambdaAndKernelHaveEqualName<NameT>()) {
22342223
extractArgsAndReqs();
22352224
MKernelName = getKernelName();
@@ -3357,6 +3346,7 @@ class __SYCL_EXPORT handler {
33573346
size_t Size, bool Block = false);
33583347
friend class ext::oneapi::experimental::detail::graph_impl;
33593348
friend class ext::oneapi::experimental::detail::dynamic_parameter_impl;
3349+
friend class ext::oneapi::experimental::detail::dynamic_command_group_impl;
33603350

33613351
bool DisableRangeRounding();
33623352

@@ -3626,8 +3616,10 @@ class __SYCL_EXPORT handler {
36263616
}
36273617
#endif
36283618

3619+
#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
36293620
// Set that an ND Range was used during a call to parallel_for
36303621
void setNDRangeUsed(bool Value);
3622+
#endif
36313623

36323624
inline void internalProfilingTagImpl() {
36333625
throwIfActionIsCreated();

sycl/source/detail/cg.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ class CGExecKernel : public CG {
257257
std::string MKernelName;
258258
std::vector<std::shared_ptr<detail::stream_impl>> MStreams;
259259
std::vector<std::shared_ptr<const void>> MAuxiliaryResources;
260+
/// Used to implement ext_oneapi_graph dynamic_command_group. Stores the list
261+
/// of command-groups that a kernel command can be updated to.
262+
std::vector<std::weak_ptr<CGExecKernel>> MAlternativeKernels;
260263
ur_kernel_cache_config_t MKernelCacheConfig;
261264
bool MKernelIsCooperative = false;
262265
bool MKernelUsesClusterLaunch = false;
@@ -277,7 +280,7 @@ class CGExecKernel : public CG {
277280
MKernelBundle(std::move(KernelBundle)), MArgs(std::move(Args)),
278281
MKernelName(std::move(KernelName)), MStreams(std::move(Streams)),
279282
MAuxiliaryResources(std::move(AuxiliaryResources)),
280-
MKernelCacheConfig(std::move(KernelCacheConfig)),
283+
MAlternativeKernels{}, MKernelCacheConfig(std::move(KernelCacheConfig)),
281284
MKernelIsCooperative(KernelIsCooperative),
282285
MKernelUsesClusterLaunch(MKernelUsesClusterLaunch) {
283286
assert(getType() == CGType::Kernel && "Wrong type of exec kernel CG.");

0 commit comments

Comments
 (0)