Skip to content

[NFCI][SYCL][Graph] Refactor graph_impl::add #19351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 77 additions & 103 deletions sycl/source/detail/graph/graph_impl.cpp

Large diffs are not rendered by default.

87 changes: 44 additions & 43 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ class partition {
partition() : MSchedule(), MCommandBuffers() {}

/// List of root nodes.
std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
MRoots;
std::set<node_impl *> MRoots;
/// Execution schedule of nodes in the graph.
std::list<std::shared_ptr<node_impl>> MSchedule;
std::list<node_impl *> MSchedule;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, ur_exp_command_buffer_handle_t>
MCommandBuffers;
Expand Down Expand Up @@ -84,17 +83,20 @@ class partition {
// replaced every time the partition is executed.
EventImplPtr MEvent;

nodes_range roots() const { return MRoots; }
nodes_range schedule() const { return MSchedule; }

/// Checks if the graph is single path, i.e. each node has a single successor.
/// @return True if the graph is a single path
bool checkIfGraphIsSinglePath() {
if (MRoots.size() > 1) {
return false;
}
for (const auto &Node : MSchedule) {
for (node_impl &Node : schedule()) {
// In version 1.3.28454 of the L0 driver, 2D Copy ops cannot not
// be enqueued in an in-order cmd-list (causing execution to stall).
// The 2D Copy test should be removed from here when the bug is fixed.
if ((Node->MSuccessors.size() > 1) || (Node->isNDCopyNode())) {
if ((Node.MSuccessors.size() > 1) || (Node.isNDCopyNode())) {
return false;
}
}
Expand All @@ -103,7 +105,7 @@ class partition {
}

/// Add nodes to MSchedule.
void schedule();
void updateSchedule();
};

/// Implementation details of command_graph<modifiable>.
Expand All @@ -126,7 +128,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {

/// Remove node from list of root nodes.
/// @param Root Node to remove from list of root nodes.
void removeRoot(const std::shared_ptr<node_impl> &Root);
void removeRoot(node_impl &Root);

/// Verifies the CG is valid to add to the graph and returns set of
/// dependent nodes if so.
Expand All @@ -145,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param CommandGroup The CG which stores all information for this node.
/// @param Deps Dependencies of the created node.
/// @return Created node in the graph.
std::shared_ptr<node_impl> add(node_type NodeType,
std::shared_ptr<sycl::detail::CG> CommandGroup,
nodes_range Deps);
node_impl &add(node_type NodeType,
std::shared_ptr<sycl::detail::CG> CommandGroup,
nodes_range Deps);

/// Create a CGF node in the graph.
/// @param CGF Command-group function to create node with.
/// @param Args Node arguments.
/// @param Deps Dependencies of the created node.
/// @return Created node in the graph.
std::shared_ptr<node_impl> add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
std::vector<std::shared_ptr<node_impl>> &Deps);
node_impl &add(std::function<void(handler &)> CGF,
const std::vector<sycl::detail::ArgDesc> &Args,
nodes_range Deps);

/// Create an empty node in the graph.
/// @param Deps List of predecessor nodes.
/// @return Created node in the graph.
std::shared_ptr<node_impl> add(nodes_range Deps);
node_impl &add(nodes_range Deps);

/// Create a dynamic command-group node in the graph.
/// @param DynCGImpl Dynamic command-group used to create node.
/// @param Deps List of predecessor nodes.
/// @return Created node in the graph.
std::shared_ptr<node_impl>
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
nodes_range Deps);

/// Add a queue to the set of queues which are currently recording to this
/// graph.
Expand All @@ -190,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param EventImpl Event to associate with a node in map.
/// @param NodeImpl Node to associate with event in map.
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
const std::shared_ptr<node_impl> &NodeImpl) {
node_impl &NodeImpl) {
if (!(EventImpl->hasCommandGraph()))
EventImpl->setCommandGraph(shared_from_this());
MEventsMap[EventImpl] = NodeImpl;
MEventsMap[EventImpl] = NodeImpl.shared_from_this();
}

/// Find the sycl event associated with a node.
Expand Down Expand Up @@ -281,15 +283,16 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
sycl::device getDevice() const { return MDevice; }

/// List of root nodes.
std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
MRoots;
std::set<node_impl *> MRoots;

/// Storage for all nodes contained within a graph. Nodes are connected to
/// each other via weak_ptrs and so do not extend each other's lifetimes.
/// This storage allows easy iteration over all nodes in the graph, rather
/// than needing an expensive depth first search.
std::vector<std::shared_ptr<node_impl>> MNodeStorage;

nodes_range roots() const { return MRoots; }

/// Find the last node added to this graph from an in-order queue.
/// @param Queue In-order queue to find the last node added to the graph from.
/// @return Last node in this graph added from \p Queue recording, or empty
Expand All @@ -312,8 +315,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
std::fstream Stream(FilePath, std::ios::out);
Stream << "digraph dot {" << std::endl;

for (std::weak_ptr<node_impl> Node : MRoots)
Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
for (node_impl &Node : roots())
Node.printDotRecursive(Stream, VisitedNodes, Verbose);

Stream << "}" << std::endl;

Expand Down Expand Up @@ -418,13 +421,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
}

size_t RootsFound = 0;
for (std::weak_ptr<node_impl> NodeA : MRoots) {
for (std::weak_ptr<node_impl> NodeB : Graph.MRoots) {
auto NodeALocked = NodeA.lock();
auto NodeBLocked = NodeB.lock();

if (NodeALocked->isSimilar(*NodeBLocked)) {
if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) {
for (node_impl &NodeA : roots()) {
for (node_impl &NodeB : Graph.roots()) {
if (NodeA.isSimilar(NodeB)) {
if (checkNodeRecursive(NodeA, NodeB)) {
RootsFound++;
break;
}
Expand Down Expand Up @@ -510,6 +510,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
}

private:
template <typename... Ts> node_impl &createNode(Ts &&...Args) {
MNodeStorage.push_back(
std::make_shared<node_impl>(std::forward<Ts>(Args)...));
return *MNodeStorage.back();
Comment on lines +514 to +516
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if node_impl creation is under a mutex or not. If races are possible, then might need to change to

   auto Ptr = make_shared();
   node_impl &Res = *Ptr;
   MNodeStorage.push_back(std::move(Ptr));
   return Res;

If that's the case, then nodes_range over std::vector<node> optimization needs to be examined for the race conditions as well.

}

/// Check the graph for cycles by performing a depth-first search of the
/// graph. If a node is visited more than once in a given path through the
/// graph, a cycle is present and the search ends immediately.
Expand All @@ -518,18 +524,18 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void addRoot(const std::shared_ptr<node_impl> &Root);
void addRoot(node_impl &Root);

/// Adds dependencies for a new node, if it has no deps it will be
/// added as a root node.
/// @param Node The node to add deps for
/// @param Deps List of dependent nodes
void addDepsToNode(const std::shared_ptr<node_impl> &Node, nodes_range Deps) {
void addDepsToNode(node_impl &Node, nodes_range Deps) {
for (node_impl &N : Deps) {
N.registerSuccessor(Node);
this->removeRoot(Node);
}
if (Node->MPredecessors.empty()) {
if (Node.MPredecessors.empty()) {
this->addRoot(Node);
}
}
Expand Down Expand Up @@ -647,9 +653,7 @@ class exec_graph_impl {

/// Query the scheduling of node execution.
/// @return List of nodes in execution order.
const std::list<std::shared_ptr<node_impl>> &getSchedule() const {
return MSchedule;
}
const std::list<node_impl *> &getSchedule() const { return MSchedule; }

/// Query the graph_impl.
/// @return pointer to the graph_impl MGraphImpl
Expand Down Expand Up @@ -730,8 +734,7 @@ class exec_graph_impl {
/// @param Node The node being enqueued.
/// @return UR sync point created for this node in the command-buffer.
ur_exp_command_buffer_sync_point_t
enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node);
enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node);

/// Enqueue a node directly to the command-buffer without going through the
/// scheduler.
Expand All @@ -740,11 +743,9 @@ class exec_graph_impl {
/// @param CommandBuffer Command-buffer to add node to as a command.
/// @param Node The node being enqueued.
/// @return UR sync point created for this node in the command-buffer.
ur_exp_command_buffer_sync_point_t
enqueueNodeDirect(const sycl::context &Ctx,
sycl::detail::device_impl &DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node);
ur_exp_command_buffer_sync_point_t enqueueNodeDirect(
const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node);

/// Enqueues a host-task partition (i.e. a partition that contains only a
/// single node and that node is a host-task).
Expand Down Expand Up @@ -873,7 +874,7 @@ class exec_graph_impl {
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const;

/// Execution schedule of nodes in the graph.
std::list<std::shared_ptr<node_impl>> MSchedule;
std::list<node_impl *> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
/// Thread-safe implementation note: in the current implementation
Expand Down
7 changes: 2 additions & 5 deletions sycl/source/detail/graph/node_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,11 @@ std::vector<node> createNodesFromImpls(
return Nodes;
}

/// Takes a vector of shared_ptrs to node_impls and returns a vector of node
/// objects created from those impls, in the same order.
std::vector<node> createNodesFromImpls(
const std::vector<std::shared_ptr<detail::node_impl>> &Impls) {
std::vector<node> createNodesFromImpls(nodes_range Impls) {
std::vector<node> Nodes{};
Nodes.reserve(Impls.size());

for (std::shared_ptr<detail::node_impl> Impl : Impls) {
for (detail::node_impl &Impl : Impls) {
Nodes.push_back(sycl::detail::createSyclObjFromImpl<node>(Impl));
}

Expand Down
42 changes: 21 additions & 21 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
#include <sycl/detail/cg_types.hpp> // for CGType
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t

#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node

#include <cstring>
#include <fstream>
#include <iomanip>
#include <list>
#include <set>
#include <vector>

Expand All @@ -25,8 +28,6 @@ inline namespace _V1 {
namespace ext {
namespace oneapi {
namespace experimental {
// Forward declarations
class node;

namespace detail {
// Forward declarations
Expand All @@ -39,10 +40,7 @@ class exec_graph_impl;
std::vector<node>
createNodesFromImpls(const std::vector<std::weak_ptr<node_impl>> &Impls);

/// Takes a vector of shared_ptrs to node_impls and returns a vector of node
/// objects created from those impls, in the same order.
std::vector<node>
createNodesFromImpls(const std::vector<std::shared_ptr<node_impl>> &Impls);
std::vector<node> createNodesFromImpls(nodes_range Impls);

inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
using sycl::detail::CG;
Expand Down Expand Up @@ -123,27 +121,27 @@ class node_impl : public std::enable_shared_from_this<node_impl> {

/// Add successor to the node.
/// @param Node Node to add as a successor.
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
void registerSuccessor(node_impl &Node) {
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
[Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock() == Node;
[&Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock().get() == &Node;
}) != MSuccessors.end()) {
return;
}
MSuccessors.push_back(Node);
Node->registerPredecessor(shared_from_this());
MSuccessors.push_back(Node.weak_from_this());
Node.registerPredecessor(*this);
}

/// Add predecessor to the node.
/// @param Node Node to add as a predecessor.
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
void registerPredecessor(node_impl &Node) {
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
[&Node](const std::weak_ptr<node_impl> &Ptr) {
return Ptr.lock() == Node;
return Ptr.lock().get() == &Node;
}) != MPredecessors.end()) {
return;
}
MPredecessors.push_back(Node);
MPredecessors.push_back(Node.weak_from_this());
}

/// Construct an empty node.
Expand Down Expand Up @@ -774,7 +772,9 @@ class nodes_range {
// from `weak_ptr`s this alternative should be removed too.
std::vector<std::weak_ptr<node_impl>>,
//
std::set<std::shared_ptr<node_impl>>>;
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
//
std::list<node_impl *>, std::vector<node>>;

storage_iter Begin;
storage_iter End;
Expand All @@ -783,10 +783,8 @@ class nodes_range {
public:
nodes_range(const nodes_range &Other) = default;

template <
typename ContainerTy,
typename = std::enable_if_t<!std::is_same_v<nodes_range, ContainerTy>>>
nodes_range(ContainerTy &Container)
template <typename ContainerTy>
nodes_range(const ContainerTy &Container)
: Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} {
}

Expand All @@ -812,12 +810,14 @@ class nodes_range {
return std::visit(
[](auto &&It) -> node_impl & {
auto &Elem = *It;
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
std::weak_ptr<node_impl>>) {
using Ty = std::decay_t<decltype(Elem)>;
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
// This assumes that weak_ptr doesn't actually manage lifetime and
// the object is guaranteed to be alive (which seems to be the
// assumption across all graph code).
return *Elem.lock();
} else if constexpr (std::is_same_v<Ty, node>) {
return *getSyclObjImpl(Elem);
} else {
return *Elem;
}
Expand Down
Loading