@@ -37,13 +37,18 @@ int main() {
37
37
if (!IH.ext_codeplay_has_graph ()) {
38
38
assert (false && " Native Handle should have a graph" );
39
39
}
40
- // Newly created stream for this node
41
- auto NativeStream = IH.get_native_queue <backend::ext_oneapi_cuda>();
42
40
// Graph already created with cuGraphCreate
43
41
CUgraph NativeGraph =
44
42
IH.ext_codeplay_get_native_graph <backend::ext_oneapi_cuda>();
45
43
46
44
// Start stream capture
45
+ // From CUDA 12.3 onwards we can use cuStreamBeginCaptureToGraph to capture
46
+ // the stream directly in the native graph, rather than needing to
47
+ // instantiate the stream capture as a new graph.
48
+ #if CUDA_VERSION >= 12030
49
+ // Newly created stream for this node
50
+ auto NativeStream = IH.get_native_queue <backend::ext_oneapi_cuda>();
51
+
47
52
auto Res = cuStreamBeginCaptureToGraph (NativeStream, NativeGraph, nullptr ,
48
53
nullptr , 0 ,
49
54
CU_STREAM_CAPTURE_MODE_GLOBAL);
@@ -68,6 +73,53 @@ int main() {
68
73
69
74
Res = cuStreamEndCapture (NativeStream, &NativeGraph);
70
75
assert (Res == CUDA_SUCCESS);
76
+ #else
77
+ // Use explicit graph building API to add alloc/free nodes when
78
+ // cuStreamBeginCaptureToGraph isn't available (CUDA versions before 12.3)
79
+ auto Device = IH.get_native_device <backend::ext_oneapi_cuda>();
80
+ CUDA_MEM_ALLOC_NODE_PARAMS AllocParams{};
81
+ AllocParams.bytesize = Size * sizeof (int32_t );
82
+ AllocParams.poolProps .allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
83
+ AllocParams.poolProps .location .id = Device;
84
+ AllocParams.poolProps .location .type = CU_MEM_LOCATION_TYPE_DEVICE;
85
+ CUgraphNode AllocNode;
86
+ auto Res = cuGraphAddMemAllocNode (&AllocNode, NativeGraph, nullptr , 0 ,
87
+ &AllocParams);
88
+ assert (Res == CUDA_SUCCESS);
89
+
90
+ CUdeviceptr PtrAsync = AllocParams.dptr ;
91
+ CUDA_MEMSET_NODE_PARAMS MemsetParams{};
92
+ MemsetParams.dst = PtrAsync;
93
+ MemsetParams.elementSize = sizeof (int32_t );
94
+ MemsetParams.height = Size;
95
+ MemsetParams.pitch = sizeof (int32_t );
96
+ MemsetParams.value = Pattern;
97
+ MemsetParams.width = 1 ;
98
+ CUgraphNode MemsetNode;
99
+ CUcontext Context = IH.get_native_context <backend::ext_oneapi_cuda>();
100
+ Res = cuGraphAddMemsetNode (&MemsetNode, NativeGraph, &AllocNode, 1 ,
101
+ &MemsetParams, Context);
102
+ assert (Res == CUDA_SUCCESS);
103
+
104
+ CUDA_MEMCPY3D MemcpyParams{};
105
+ std::memset (&MemcpyParams, 0 , sizeof (CUDA_MEMCPY3D));
106
+ MemcpyParams.srcMemoryType = CU_MEMORYTYPE_DEVICE;
107
+ MemcpyParams.srcDevice = PtrAsync;
108
+ MemcpyParams.dstMemoryType = CU_MEMORYTYPE_DEVICE;
109
+ MemcpyParams.dstDevice = (CUdeviceptr)PtrX;
110
+ MemcpyParams.WidthInBytes = Size * sizeof (int32_t );
111
+ MemcpyParams.Height = 1 ;
112
+ MemcpyParams.Depth = 1 ;
113
+ CUgraphNode MemcpyNode;
114
+ Res = cuGraphAddMemcpyNode (&MemcpyNode, NativeGraph, &MemsetNode, 1 ,
115
+ &MemcpyParams, Context);
116
+ assert (Res == CUDA_SUCCESS);
117
+
118
+ CUgraphNode FreeNode;
119
+ Res = cuGraphAddMemFreeNode (&FreeNode, NativeGraph, &MemcpyNode, 1 ,
120
+ PtrAsync);
121
+ assert (Res == CUDA_SUCCESS);
122
+ #endif
71
123
});
72
124
});
73
125
0 commit comments