8
8
9
9
#pragma once
10
10
11
+ #include " dpct/kernel.hpp"
11
12
#include " dpct/util.hpp"
12
13
#include " sycl/handler.hpp"
14
+ #include " sycl/queue.hpp"
13
15
#include < cstddef>
14
16
#include < sycl/ext/oneapi/experimental/graph.hpp>
15
17
#include < sycl/sycl.hpp>
16
18
#include < unordered_map>
17
- #include < unordered_map>
18
19
19
20
namespace dpct {
20
21
namespace experimental {
@@ -33,16 +34,18 @@ struct kernel_node_params {
33
34
dpct::dim3 block_dim;
34
35
dpct::dim3 grid_dim;
35
36
void **kernel_params;
36
- void * func;
37
+ void * func;
37
38
unsigned int shared_mem_bytes;
38
39
39
40
public:
40
- void set_block_dim (dpct::dim3 block_dim) { block_dim = block_dim; }
41
- void set_grid_dim (dpct::dim3 grid_dim) { grid_dim = grid_dim; }
42
- void set_kernel_params (void **kernel_params) { kernel_params = kernel_params; }
43
- void set_func (void *func) { func = func; }
41
+ void set_block_dim (dpct::dim3 block_dim) { this ->block_dim = block_dim; }
42
+ void set_grid_dim (dpct::dim3 grid_dim) { this ->grid_dim = grid_dim; }
43
+ void set_kernel_params (void **kernel_params) {
44
+ this ->kernel_params = kernel_params;
45
+ }
46
+ void set_func (void *func) { this ->func = func; }
44
47
void set_shared_mem_bytes (unsigned int shared_mem_bytes) {
45
- shared_mem_bytes = shared_mem_bytes;
48
+ this -> shared_mem_bytes = shared_mem_bytes;
46
49
}
47
50
dpct::dim3 get_block_dim () { return block_dim; }
48
51
dpct::dim3 get_grid_dim () { return grid_dim; }
@@ -65,10 +68,6 @@ class graph_mgr {
65
68
return instance;
66
69
}
67
70
68
- std::unordered_map<dpct::experimental::node_ptr,
69
- dpct::experimental::kernel_node_params>
70
- kernel_node_params_map;
71
-
72
71
void begin_recording (sycl::queue *queue_ptr) {
73
72
// Calling begin_recording on an already recording queue is a no-op in SYCL
74
73
if (queue_graph_map.find (queue_ptr) != queue_graph_map.end ()) {
@@ -124,6 +123,36 @@ class graph_mgr {
124
123
}
125
124
}
126
125
126
+ void add_kernel_node (dpct::experimental::node_ptr *node,
127
+ dpct::experimental::command_graph_ptr graph,
128
+ dpct::experimental::node_ptr *dependencies,
129
+ std::size_t numberOfDependencies,
130
+ dpct::experimental::kernel_node_params *params) {
131
+ kernel_node_params_map[graph].push_back (params);
132
+ }
133
+ void launch (dpct::experimental::command_graph_exec_ptr execGraph,
134
+ sycl::queue *queue) {
135
+ auto graph = exec_graph_map[execGraph];
136
+ for (auto kernel_params : kernel_node_params_map[graph]) {
137
+ graph->add ([&](sycl::handler &cgh) {
138
+ cgh.host_task ([=]() {
139
+ dpct::kernel_launcher::launch (
140
+ kernel_params->get_func (), kernel_params->get_grid_dim (),
141
+ kernel_params->get_block_dim (),
142
+ kernel_params->get_kernel_params (),
143
+ kernel_params->get_shared_mem_bytes (), queue);
144
+ });
145
+ });
146
+ }
147
+ auto final_graph = graph->finalize ();
148
+ queue->submit ([&](sycl::handler &cgh) { cgh.ext_oneapi_graph (final_graph); });
149
+ }
150
+
151
+ void instantiate (dpct::experimental::command_graph_exec_ptr *execGraph,
152
+ dpct::experimental::command_graph_ptr graph) {
153
+ exec_graph_map[*execGraph] = graph;
154
+ }
155
+
127
156
private:
128
157
std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map;
129
158
std::unordered_map<dpct::experimental::command_graph_ptr,
@@ -132,6 +161,12 @@ class graph_mgr {
132
161
std::unordered_map<dpct::experimental::command_graph_ptr,
133
162
std::vector<sycl::ext::oneapi::experimental::node>>
134
163
root_nodes_map;
164
+ std::unordered_map<dpct::experimental::command_graph_exec_ptr,
165
+ dpct::experimental::command_graph_ptr>
166
+ exec_graph_map;
167
+ std::unordered_map<dpct::experimental::command_graph_ptr,
168
+ std::vector<dpct::experimental::kernel_node_params *>>
169
+ kernel_node_params_map;
135
170
};
136
171
} // namespace detail
137
172
@@ -204,9 +239,9 @@ static void add_dependencies(dpct::experimental::command_graph_ptr graph,
204
239
// / nodes will be assigned.
205
240
// / \param [out] numberOfNodes The number of nodes in the graph.
206
241
static void get_nodes (dpct::experimental::command_graph_ptr graph,
207
- dpct::experimental::node_ptr *nodesArray,
208
- std::size_t *numberOfNodes) {
209
- detail::graph_mgr::instance ().get_nodes (graph, nodesArray, numberOfNodes);
242
+ dpct::experimental::node_ptr *nodesArray,
243
+ std::size_t *numberOfNodes) {
244
+ detail::graph_mgr::instance ().get_nodes (graph, nodesArray, numberOfNodes);
210
245
}
211
246
212
247
// / Gets the root nodes in the command graph.
@@ -215,14 +250,29 @@ detail::graph_mgr::instance().get_nodes(graph, nodesArray, numberOfNodes);
215
250
// / root nodes will be assigned.
216
251
// / \param [out] numberOfNodes The number of root nodes in the graph.
217
252
static void get_root_nodes (dpct::experimental::command_graph_ptr graph,
218
- dpct::experimental::node_ptr *nodesArray,
219
- std::size_t *numberOfNodes) {
220
- detail::graph_mgr::instance ().get_root_nodes (graph, nodesArray,
221
- numberOfNodes);
253
+ dpct::experimental::node_ptr *nodesArray,
254
+ std::size_t *numberOfNodes) {
255
+ detail::graph_mgr::instance ().get_root_nodes (graph, nodesArray,
256
+ numberOfNodes);
257
+ }
258
+
259
+ static void add_kernel_node (dpct::experimental::node_ptr *node,
260
+ dpct::experimental::command_graph_ptr graph,
261
+ dpct::experimental::node_ptr *dependencies,
262
+ std::size_t numberOfDependencies,
263
+ dpct::experimental::kernel_node_params *params) {
264
+ detail::graph_mgr::instance ().add_kernel_node (node, graph, dependencies,
265
+ numberOfDependencies, params);
222
266
}
223
267
224
- static void add_kernel_node (dpct::experimental::node_ptr* node, dpct::experimental::node_ptr* dependencies, std::size_t &numberOfDependencies, dpct::experimental::kernel_node_params* params){
268
+ static void instantiate (dpct::experimental::command_graph_exec_ptr *execGraph,
269
+ dpct::experimental::command_graph_ptr graph) {
270
+ detail::graph_mgr::instance ().instantiate (execGraph, graph);
271
+ }
225
272
273
+ static void launch (dpct::experimental::command_graph_exec_ptr execGraph,
274
+ sycl::queue *queue) {
275
+ detail::graph_mgr::instance ().launch (execGraph, queue);
226
276
}
227
277
228
278
} // namespace experimental
0 commit comments