Skip to content

Commit 98fcde3

Browse files
test2
Signed-off-by: Daiyaan Ahmed <daiyaan.ahmed@intel.com>
1 parent d7ad0fb commit 98fcde3

File tree

2 files changed

+83
-40
lines changed

2 files changed

+83
-40
lines changed

clang/lib/DPCT/RulesLang/RulesLangGraph.cpp

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,35 +55,28 @@ void GraphRule::runRule(const MatchFinder::MatchResult &Result) {
5555
*Result.Context);
5656
auto MemberName = ME->getMemberNameInfo().getAsString();
5757
if (BaseTy == "cudaKernelNodeParams") {
58+
std::cout <<"NODE PARAMS FOUND\n";
59+
DpctGlobalInfo::setCVersionCUDALaunchUsed();
5860
auto FieldName = KernelNodeParamNames[MemberName];
5961
if (FieldName.empty()) {
6062
report(ME->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false,
6163
DpctGlobalInfo::getOriginalTypeName(ME->getBase()->getType()) +
6264
"::" + ME->getMemberDecl()->getName().str());
6365
return;
66+
6467
}
6568
// if(FieldName == "func"){
66-
// if(auto BO = dyn_cast<BinaryOperator>(getParentAsAssignedBO(ME, *Result.Context))){
67-
// const Expr *RHS = BO->getRHS();
68-
// const Expr *StrippedRHS = RHS->IgnoreParenCasts();
69-
// std::string RHSStr;
70-
// llvm::raw_string_ostream OS(RHSStr);
71-
// std::cout <<"RHSSTR: " <<RHSStr << "\n";
72-
// StrippedRHS->printPretty(OS, nullptr, Result.Context->getPrintingPolicy());
73-
74-
75-
// // Create the replacement string using dpct::wrapper_register
76-
// auto ReplacementStr = "set_func.dpct::wrapper_register(&" + RHSStr + "_wrapper).get()";
77-
// std::cout<< "ReplacementSTR:" << ReplacementStr << "\n";
78-
79-
// // Replace the assignment with the set_func method call
80-
// // emplaceTransformation(ReplaceMemberAssignAsSetMethod(
81-
// // BO, ME, FieldName, ReplacementStr));
82-
// emplaceTransformation(new ReplaceText(getStmtExpansionSourceRange(RHS).getBegin(),
83-
// ReplacementStr.length(),
84-
// std::move(ReplacementStr)));
85-
// return;
86-
// }
69+
// Check for the binary operator and fetch the RHS
70+
// Strip the explicit typecast if it exists
71+
// Check for VarDecl on the StrippedRHS
72+
// If not a VarDecl, then insert user warning
73+
// Check for VarDecl Type to be a FunctionDecl
74+
// If FunctionDecl, then
75+
// VarDecl, get var name, Get kernel_node_params variable name
76+
// Create the expression, hardcoded strting
77+
// Create new replace object and emplace transformation (nodeParams.set_func((void*)dpct::wrapper_register(&incrementKernel_wrapper).get());)
78+
// If VarDecl and not a FunctionDecl and if type of VarDecl is function pointer
79+
// Create a hardcoded string (nodeParams.set_func(a.get()));
8780
// }
8881
if (auto BO = getParentAsAssignedBO(ME, *Result.Context)) {
8982
StringRef ReplacedArg = "";

clang/runtime/dpct-rt/include/dpct/graph.hpp

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
#pragma once
1010

11+
#include "dpct/kernel.hpp"
1112
#include "dpct/util.hpp"
1213
#include "sycl/handler.hpp"
14+
#include "sycl/queue.hpp"
1315
#include <cstddef>
1416
#include <sycl/ext/oneapi/experimental/graph.hpp>
1517
#include <sycl/sycl.hpp>
1618
#include <unordered_map>
17-
#include <unordered_map>
1819

1920
namespace dpct {
2021
namespace experimental {
@@ -33,16 +34,18 @@ struct kernel_node_params {
3334
dpct::dim3 block_dim;
3435
dpct::dim3 grid_dim;
3536
void **kernel_params;
36-
void* func;
37+
void *func;
3738
unsigned int shared_mem_bytes;
3839

3940
public:
40-
void set_block_dim(dpct::dim3 block_dim) { block_dim = block_dim; }
41-
void set_grid_dim(dpct::dim3 grid_dim) { grid_dim = grid_dim; }
42-
void set_kernel_params(void **kernel_params) { kernel_params = kernel_params; }
43-
void set_func(void *func) { func = func; }
41+
void set_block_dim(dpct::dim3 block_dim) { this->block_dim = block_dim; }
42+
void set_grid_dim(dpct::dim3 grid_dim) { this->grid_dim = grid_dim; }
43+
void set_kernel_params(void **kernel_params) {
44+
this->kernel_params = kernel_params;
45+
}
46+
void set_func(void *func) { this->func = func; }
4447
void set_shared_mem_bytes(unsigned int shared_mem_bytes) {
45-
shared_mem_bytes = shared_mem_bytes;
48+
this->shared_mem_bytes = shared_mem_bytes;
4649
}
4750
dpct::dim3 get_block_dim() { return block_dim; }
4851
dpct::dim3 get_grid_dim() { return grid_dim; }
@@ -65,10 +68,6 @@ class graph_mgr {
6568
return instance;
6669
}
6770

68-
std::unordered_map<dpct::experimental::node_ptr,
69-
dpct::experimental::kernel_node_params>
70-
kernel_node_params_map;
71-
7271
void begin_recording(sycl::queue *queue_ptr) {
7372
// Calling begin_recording on an already recording queue is a no-op in SYCL
7473
if (queue_graph_map.find(queue_ptr) != queue_graph_map.end()) {
@@ -124,6 +123,36 @@ class graph_mgr {
124123
}
125124
}
126125

126+
void add_kernel_node(dpct::experimental::node_ptr *node,
127+
dpct::experimental::command_graph_ptr graph,
128+
dpct::experimental::node_ptr *dependencies,
129+
std::size_t numberOfDependencies,
130+
dpct::experimental::kernel_node_params *params) {
131+
kernel_node_params_map[graph].push_back(params);
132+
}
133+
void launch(dpct::experimental::command_graph_exec_ptr execGraph,
134+
sycl::queue *queue) {
135+
auto graph = exec_graph_map[execGraph];
136+
for (auto kernel_params : kernel_node_params_map[graph]) {
137+
graph->add([&](sycl::handler &cgh) {
138+
cgh.host_task([=]() {
139+
dpct::kernel_launcher::launch(
140+
kernel_params->get_func(), kernel_params->get_grid_dim(),
141+
kernel_params->get_block_dim(),
142+
kernel_params->get_kernel_params(),
143+
kernel_params->get_shared_mem_bytes(), queue);
144+
});
145+
});
146+
}
147+
auto final_graph = graph->finalize();
148+
queue->submit([&](sycl::handler &cgh) { cgh.ext_oneapi_graph(final_graph); });
149+
}
150+
151+
void instantiate(dpct::experimental::command_graph_exec_ptr *execGraph,
152+
dpct::experimental::command_graph_ptr graph) {
153+
exec_graph_map[*execGraph] = graph;
154+
}
155+
127156
private:
128157
std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map;
129158
std::unordered_map<dpct::experimental::command_graph_ptr,
@@ -132,6 +161,12 @@ class graph_mgr {
132161
std::unordered_map<dpct::experimental::command_graph_ptr,
133162
std::vector<sycl::ext::oneapi::experimental::node>>
134163
root_nodes_map;
164+
std::unordered_map<dpct::experimental::command_graph_exec_ptr,
165+
dpct::experimental::command_graph_ptr>
166+
exec_graph_map;
167+
std::unordered_map<dpct::experimental::command_graph_ptr,
168+
std::vector<dpct::experimental::kernel_node_params *>>
169+
kernel_node_params_map;
135170
};
136171
} // namespace detail
137172

@@ -204,9 +239,9 @@ static void add_dependencies(dpct::experimental::command_graph_ptr graph,
204239
/// nodes will be assigned.
205240
/// \param [out] numberOfNodes The number of nodes in the graph.
206241
static void get_nodes(dpct::experimental::command_graph_ptr graph,
207-
dpct::experimental::node_ptr *nodesArray,
208-
std::size_t *numberOfNodes) {
209-
detail::graph_mgr::instance().get_nodes(graph, nodesArray, numberOfNodes);
242+
dpct::experimental::node_ptr *nodesArray,
243+
std::size_t *numberOfNodes) {
244+
detail::graph_mgr::instance().get_nodes(graph, nodesArray, numberOfNodes);
210245
}
211246

212247
/// Gets the root nodes in the command graph.
@@ -215,14 +250,29 @@ detail::graph_mgr::instance().get_nodes(graph, nodesArray, numberOfNodes);
215250
/// root nodes will be assigned.
216251
/// \param [out] numberOfNodes The number of root nodes in the graph.
217252
static void get_root_nodes(dpct::experimental::command_graph_ptr graph,
218-
dpct::experimental::node_ptr *nodesArray,
219-
std::size_t *numberOfNodes) {
220-
detail::graph_mgr::instance().get_root_nodes(graph, nodesArray,
221-
numberOfNodes);
253+
dpct::experimental::node_ptr *nodesArray,
254+
std::size_t *numberOfNodes) {
255+
detail::graph_mgr::instance().get_root_nodes(graph, nodesArray,
256+
numberOfNodes);
257+
}
258+
259+
static void add_kernel_node(dpct::experimental::node_ptr *node,
260+
dpct::experimental::command_graph_ptr graph,
261+
dpct::experimental::node_ptr *dependencies,
262+
std::size_t numberOfDependencies,
263+
dpct::experimental::kernel_node_params *params) {
264+
detail::graph_mgr::instance().add_kernel_node(node, graph, dependencies,
265+
numberOfDependencies, params);
222266
}
223267

224-
static void add_kernel_node(dpct::experimental::node_ptr* node, dpct::experimental::node_ptr* dependencies, std::size_t &numberOfDependencies, dpct::experimental::kernel_node_params* params){
268+
static void instantiate(dpct::experimental::command_graph_exec_ptr *execGraph,
269+
dpct::experimental::command_graph_ptr graph) {
270+
detail::graph_mgr::instance().instantiate(execGraph, graph);
271+
}
225272

273+
static void launch(dpct::experimental::command_graph_exec_ptr execGraph,
274+
sycl::queue *queue) {
275+
detail::graph_mgr::instance().launch(execGraph, queue);
226276
}
227277

228278
} // namespace experimental

0 commit comments

Comments
 (0)