Skip to content

Commit 7c0dc6d

Browse files
authored
Merge pull request #2425 from Bensuo/ben/cuda-event-fix
[CUDA] Fix potential issue with command buffer fills on CUDA
2 parents 06a8c51 + 76054dd commit 7c0dc6d

File tree

5 files changed

+249
-22
lines changed

5 files changed

+249
-22
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -236,16 +236,29 @@ static ur_result_t enqueueCommandBufferFillHelper(
236236
EventWaitList));
237237
}
238238

239+
// CUDA has no memset functions that allow setting values more than 4
240+
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
241+
// fill, which can be more than 4 bytes. Calculate the number of steps
242+
// required here to see if decomposing to multiple fill nodes is required.
243+
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
244+
239245
// Graph node added to graph, if multiple nodes are created this will
240246
// be set to the leaf node
241247
CUgraphNode GraphNode;
248+
// Track if multiple nodes are created so we can pass them to the command
249+
// handle
250+
std::vector<CUgraphNode> DecomposedNodes;
251+
252+
if (NumberOfSteps > 4) {
253+
DecomposedNodes.reserve(NumberOfSteps);
254+
}
242255

243256
const size_t N = Size / PatternSize;
244257
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
245258
? *static_cast<CUdeviceptr *>(DstDevice)
246259
: (CUdeviceptr)DstDevice;
247260

248-
if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
261+
if (NumberOfSteps <= 4) {
249262
CUDA_MEMSET_NODE_PARAMS NodeParams = {};
250263
NodeParams.dst = DstPtr;
251264
NodeParams.elementSize = PatternSize;
@@ -276,14 +289,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
276289
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
277290
&NodeParams, CommandBuffer->Device->getNativeContext()));
278291
} else {
279-
// CUDA has no memset functions that allow setting values more than 4
280-
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
281-
// fill, which can be more than 4 bytes. We must break up the pattern
282-
// into 1 byte values, and set the buffer using multiple strided calls.
283-
// This means that one cuGraphAddMemsetNode call is made for every 1
284-
// bytes in the pattern.
285-
286-
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
292+
// We must break up the rest of the pattern into 1 byte values, and set
293+
// the buffer using multiple strided calls. This means that one
294+
// cuGraphAddMemsetNode call is made for every byte in the pattern.
287295

288296
// Update NodeParam
289297
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
@@ -294,12 +302,13 @@ static ur_result_t enqueueCommandBufferFillHelper(
294302
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
295303
NodeParamsStepFirst.width = 1;
296304

305+
// Initial decomposed node depends on the provided external event wait
306+
// nodes
297307
UR_CHECK_ERROR(cuGraphAddMemsetNode(
298308
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(), DepsList.size(),
299309
&NodeParamsStepFirst, CommandBuffer->Device->getNativeContext()));
300310

301-
DepsList.clear();
302-
DepsList.push_back(GraphNode);
311+
DecomposedNodes.push_back(GraphNode);
303312

304313
// we walk up the pattern in 1-byte steps, and call cuMemset for each
305314
// 1-byte chunk of the pattern.
@@ -319,13 +328,16 @@ static ur_result_t enqueueCommandBufferFillHelper(
319328
NodeParamsStep.value = Value;
320329
NodeParamsStep.width = 1;
321330

331+
// Copy the last GraphNode ptr so we can pass it as the dependency for
332+
// the next one
333+
CUgraphNode PrevNode = GraphNode;
334+
322335
UR_CHECK_ERROR(cuGraphAddMemsetNode(
323-
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
324-
DepsList.size(), &NodeParamsStep,
336+
&GraphNode, CommandBuffer->CudaGraph, &PrevNode, 1, &NodeParamsStep,
325337
CommandBuffer->Device->getNativeContext()));
326338

327-
DepsList.clear();
328-
DepsList.push_back(GraphNode);
339+
// Store the decomposed node
340+
DecomposedNodes.push_back(GraphNode);
329341
}
330342
}
331343

@@ -344,7 +356,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
344356

345357
std::vector<CUgraphNode> WaitNodes =
346358
NumEventsInWaitList ? std::move(DepsList) : std::vector<CUgraphNode>();
347-
auto NewCommand = new T(CommandBuffer, GraphNode, SignalNode, WaitNodes);
359+
auto NewCommand = new T(CommandBuffer, GraphNode, SignalNode, WaitNodes,
360+
std::move(DecomposedNodes));
348361
CommandBuffer->CommandHandles.push_back(NewCommand);
349362

350363
if (RetCommand) {

source/adapters/cuda/command_buffer.hpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,19 @@ struct usm_memcpy_command_handle : ur_exp_command_buffer_command_handle_t_ {
172172
struct usm_fill_command_handle : ur_exp_command_buffer_command_handle_t_ {
173173
usm_fill_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
174174
CUgraphNode Node, CUgraphNode SignalNode,
175-
const std::vector<CUgraphNode> &WaitNodes)
175+
const std::vector<CUgraphNode> &WaitNodes,
176+
const std::vector<CUgraphNode> &DecomposedNodes = {})
176177
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
177-
WaitNodes) {}
178+
WaitNodes),
179+
DecomposedNodes(std::move(DecomposedNodes)) {}
178180
CommandType getCommandType() const noexcept override {
179181
return CommandType::USMFill;
180182
}
183+
184+
// If this fill command was decomposed into multiple nodes, this vector
185+
// contains all of those nodes in the order they were added to the graph.
186+
// Currently unused but will be required for updating in future.
187+
std::vector<CUgraphNode> DecomposedNodes;
181188
};
182189

183190
struct buffer_copy_command_handle : ur_exp_command_buffer_command_handle_t_ {
@@ -250,14 +257,21 @@ struct buffer_write_rect_command_handle
250257
};
251258

252259
struct buffer_fill_command_handle : ur_exp_command_buffer_command_handle_t_ {
253-
buffer_fill_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
254-
CUgraphNode Node, CUgraphNode SignalNode,
255-
const std::vector<CUgraphNode> &WaitNodes)
260+
buffer_fill_command_handle(
261+
ur_exp_command_buffer_handle_t CommandBuffer, CUgraphNode Node,
262+
CUgraphNode SignalNode, const std::vector<CUgraphNode> &WaitNodes,
263+
const std::vector<CUgraphNode> &DecomposedNodes = {})
256264
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, Node, SignalNode,
257-
WaitNodes) {}
265+
WaitNodes),
266+
DecomposedNodes(std::move(DecomposedNodes)) {}
258267
CommandType getCommandType() const noexcept override {
259268
return CommandType::MemBufferFill;
260269
}
270+
271+
// If this fill command was decomposed into multiple nodes, this vector
272+
// contains all of those nodes in the order they were added to the graph.
273+
// Currently unused but will be required for updating in future.
274+
std::vector<CUgraphNode> DecomposedNodes;
261275
};
262276

263277
struct usm_prefetch_command_handle : ur_exp_command_buffer_command_handle_t_ {

test/conformance/exp_command_buffer/event_sync.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,42 @@ TEST_P(CommandEventSyncTest, USMFillExp) {
7575
}
7676
}
7777

78+
// Test fill using a large pattern size since implementations may need to handle
79+
// this differently.
80+
TEST_P(CommandEventSyncTest, USMFillLargePatternExp) {
81+
// Device ptrs are allocated in the test fixture with 32-bit values * num
82+
// elements, since we are doubling the pattern size we want to treat those
83+
// device pointers as if they were created with half the number of elements.
84+
constexpr size_t modifiedElementSize = elements / 2;
85+
// Get wait event from queue fill on ptr 0
86+
uint64_t patternX = 42;
87+
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternX),
88+
&patternX, allocation_size, 0, nullptr,
89+
&external_events[0]));
90+
91+
// Test fill command overwriting ptr 0 waiting on queue event
92+
uint64_t patternY = 0xA;
93+
ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
94+
cmd_buf_handle, device_ptrs[0], &patternY, sizeof(patternY),
95+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
96+
&external_events[1], nullptr));
97+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
98+
ASSERT_SUCCESS(
99+
urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
100+
101+
// Queue read ptr 0 based on event returned from command-buffer command
102+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
103+
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
104+
device_ptrs[0], allocation_size, 1,
105+
&external_events[1], nullptr));
106+
107+
// Verify
108+
ASSERT_SUCCESS(urQueueFinish(queue));
109+
for (size_t i = 0; i < modifiedElementSize; i++) {
110+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
111+
}
112+
}
113+
78114
TEST_P(CommandEventSyncTest, MemBufferCopyExp) {
79115
// Get wait event from queue fill on buffer 0
80116
uint32_t patternX = 42;
@@ -341,6 +377,42 @@ TEST_P(CommandEventSyncTest, MemBufferFillExp) {
341377
}
342378
}
343379

380+
// Test fill using a large pattern size since implementations may need to handle
381+
// this differently.
382+
TEST_P(CommandEventSyncTest, MemBufferFillLargePatternExp) {
383+
// Device buffers are allocated in the test fixture with 32-bit values * num
384+
// elements, since we are doubling the pattern size we want to treat those
385+
// device buffers as if they were created with half the number of elements.
386+
constexpr size_t modifiedElementSize = elements / 2;
387+
// Get wait event from queue fill on buffer 0
388+
uint64_t patternX = 42;
389+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternX,
390+
sizeof(patternX), 0, allocation_size,
391+
0, nullptr, &external_events[0]));
392+
393+
// Test fill command overwriting buffer 0 based on queue event
394+
uint64_t patternY = 0xA;
395+
ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
396+
cmd_buf_handle, buffers[0], &patternY, sizeof(patternY), 0,
397+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
398+
&external_events[1], nullptr));
399+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(cmd_buf_handle));
400+
ASSERT_SUCCESS(
401+
urCommandBufferEnqueueExp(cmd_buf_handle, queue, 0, nullptr, nullptr));
402+
403+
// Queue read buffer 0 based on event returned from command-buffer command
404+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
405+
ASSERT_SUCCESS(urEnqueueMemBufferRead(
406+
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
407+
1, &external_events[1], nullptr));
408+
409+
// Verify
410+
ASSERT_SUCCESS(urQueueFinish(queue));
411+
for (size_t i = 0; i < modifiedElementSize; i++) {
412+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
413+
}
414+
}
415+
344416
TEST_P(CommandEventSyncTest, USMPrefetchExp) {
345417
// Get wait event from queue fill on ptr 0
346418
uint32_t patternX = 42;

test/conformance/exp_command_buffer/exp_command_buffer_adapter_level_zero_v2.match

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ CommandEventSyncTest.USMPrefetchExp/*
3333
CommandEventSyncTest.USMAdviseExp/*
3434
CommandEventSyncTest.MultipleEventCommands/*
3535
CommandEventSyncTest.MultipleEventCommandsBetweenCommandBuffers/*
36+
CommandEventSyncTest.USMFillLargePatternExp/*
37+
CommandEventSyncTest.MemBufferFillLargePatternExp/*
3638
CommandEventSyncUpdateTest.USMMemcpyExp/*
3739
CommandEventSyncUpdateTest.USMFillExp/*
3840
CommandEventSyncUpdateTest.MemBufferCopyExp/*
@@ -45,3 +47,5 @@ CommandEventSyncUpdateTest.MemBufferFillExp/*
4547
CommandEventSyncUpdateTest.USMPrefetchExp/*
4648
CommandEventSyncUpdateTest.USMAdviseExp/*
4749
CommandEventSyncUpdateTest.MultipleEventCommands/*
50+
CommandEventSyncUpdateTest.USMFillLargePatternExp/*
51+
CommandEventSyncUpdateTest.MemBufferFillLargePatternExp/*

test/conformance/exp_command_buffer/update/event_sync.cpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,68 @@ TEST_P(CommandEventSyncUpdateTest, USMFillExp) {
129129
}
130130
}
131131

132+
// Test fill using a large pattern size since implementations may need to handle
133+
// this differently.
134+
TEST_P(CommandEventSyncUpdateTest, USMFillLargePatternExp) {
135+
// Device ptrs are allocated in the test fixture with 32-bit values * num
136+
// elements, since we are doubling the pattern size we want to treat those
137+
// device pointers as if they were created with half the number of elements.
138+
constexpr size_t modifiedElementSize = elements / 2;
139+
// Get wait event from queue fill on ptr 0
140+
uint64_t patternX = 42;
141+
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternX),
142+
&patternX, allocation_size, 0, nullptr,
143+
&external_events[0]));
144+
145+
// Test fill command overwriting ptr 0 waiting on queue event
146+
uint64_t patternY = 0xA;
147+
ASSERT_SUCCESS(urCommandBufferAppendUSMFillExp(
148+
updatable_cmd_buf_handle, device_ptrs[0], &patternY, sizeof(patternY),
149+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
150+
&external_events[1], &command_handles[0]));
151+
ASSERT_NE(nullptr, command_handles[0]);
152+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
153+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
154+
nullptr, nullptr));
155+
156+
// Queue read ptr 0 based on event returned from command-buffer command
157+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
158+
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
159+
device_ptrs[0], allocation_size, 1,
160+
&external_events[1], nullptr));
161+
162+
// Verify
163+
ASSERT_SUCCESS(urQueueFinish(queue));
164+
for (size_t i = 0; i < modifiedElementSize; i++) {
165+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
166+
}
167+
168+
uint64_t patternZ = 666;
169+
ASSERT_SUCCESS(urEnqueueUSMFill(queue, device_ptrs[0], sizeof(patternZ),
170+
&patternZ, allocation_size, 0, nullptr,
171+
&external_events[2]));
172+
173+
// Update command command-wait event to wait on fill of new value
174+
ASSERT_SUCCESS(urCommandBufferUpdateWaitEventsExp(command_handles[0], 1,
175+
&external_events[2]));
176+
177+
// Get a new signal event for command-buffer
178+
ASSERT_SUCCESS(urCommandBufferUpdateSignalEventExp(command_handles[0],
179+
&external_events[3]));
180+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
181+
nullptr, nullptr));
182+
183+
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, false, host_enqueue_ptr.data(),
184+
device_ptrs[0], allocation_size, 1,
185+
&external_events[3], nullptr));
186+
187+
// Verify update
188+
ASSERT_SUCCESS(urQueueFinish(queue));
189+
for (size_t i = 0; i < modifiedElementSize; i++) {
190+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
191+
}
192+
}
193+
132194
TEST_P(CommandEventSyncUpdateTest, MemBufferCopyExp) {
133195
// Get wait event from queue fill on buffer 0
134196
uint32_t patternX = 42;
@@ -532,6 +594,68 @@ TEST_P(CommandEventSyncUpdateTest, MemBufferWriteRectExp) {
532594
}
533595
}
534596

597+
// Test fill using a large pattern size since implementations may need to handle
598+
// this differently.
599+
TEST_P(CommandEventSyncUpdateTest, MemBufferFillLargePatternExp) {
600+
// Device buffers are allocated in the test fixture with 32-bit values * num
601+
// elements, since we are doubling the pattern size we want to treat those
602+
// device buffers as if they were created with half the number of elements.
603+
constexpr size_t modifiedElementSize = elements / 2;
604+
// Get wait event from queue fill on buffer 0
605+
uint64_t patternX = 42;
606+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternX,
607+
sizeof(patternX), 0, allocation_size,
608+
0, nullptr, &external_events[0]));
609+
610+
// Test fill command overwriting buffer 0 based on queue event
611+
uint64_t patternY = 0xA;
612+
ASSERT_SUCCESS(urCommandBufferAppendMemBufferFillExp(
613+
updatable_cmd_buf_handle, buffers[0], &patternY, sizeof(patternY), 0,
614+
allocation_size, 0, nullptr, 1, &external_events[0], nullptr,
615+
&external_events[1], &command_handles[0]));
616+
ASSERT_NE(nullptr, command_handles[0]);
617+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
618+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
619+
nullptr, nullptr));
620+
621+
// Queue read buffer 0 based on event returned from command-buffer command
622+
std::array<uint64_t, modifiedElementSize> host_enqueue_ptr{};
623+
ASSERT_SUCCESS(urEnqueueMemBufferRead(
624+
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
625+
1, &external_events[1], nullptr));
626+
627+
// Verify
628+
ASSERT_SUCCESS(urQueueFinish(queue));
629+
for (size_t i = 0; i < modifiedElementSize; i++) {
630+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
631+
}
632+
633+
uint64_t patternZ = 666;
634+
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffers[0], &patternZ,
635+
sizeof(patternZ), 0, allocation_size,
636+
0, nullptr, &external_events[2]));
637+
638+
// Update command command-wait event to wait on fill of new value
639+
ASSERT_SUCCESS(urCommandBufferUpdateWaitEventsExp(command_handles[0], 1,
640+
&external_events[2]));
641+
642+
// Get a new signal event for command-buffer
643+
ASSERT_SUCCESS(urCommandBufferUpdateSignalEventExp(command_handles[0],
644+
&external_events[3]));
645+
646+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
647+
nullptr, nullptr));
648+
ASSERT_SUCCESS(urEnqueueMemBufferRead(
649+
queue, buffers[0], false, 0, allocation_size, host_enqueue_ptr.data(),
650+
1, &external_events[3], nullptr));
651+
652+
// Verify update
653+
ASSERT_SUCCESS(urQueueFinish(queue));
654+
for (size_t i = 0; i < modifiedElementSize; i++) {
655+
ASSERT_EQ(host_enqueue_ptr[i], patternY);
656+
}
657+
}
658+
535659
TEST_P(CommandEventSyncUpdateTest, MemBufferFillExp) {
536660
// Get wait event from queue fill on buffer 0
537661
uint32_t patternX = 42;

0 commit comments

Comments
 (0)