Skip to content

Commit a9a325d

Browse files
mfrancepillois and EwanC
authored and committed
Adds CUDA support
1 parent 2882e1f commit a9a325d

File tree

1 file changed

+110
-27
lines changed

1 file changed

+110
-27
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 110 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,90 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
9999
Params.Depth = 1;
100100
}
101101

102+
// Helper function for enqueuing memory fills
103+
static ur_result_t enqueueCommandBufferFillHelper(
104+
ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
105+
const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
106+
size_t Size, uint32_t NumSyncPointsInWaitList,
107+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
108+
ur_exp_command_buffer_sync_point_t *SyncPoint) {
109+
ur_result_t Result;
110+
std::vector<CUgraphNode> DepsList;
111+
UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
112+
SyncPointWaitList, DepsList));
113+
114+
try {
115+
size_t N = Size / PatternSize;
116+
auto Value = *static_cast<const uint32_t *>(Pattern);
117+
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
118+
? *static_cast<CUdeviceptr *>(DstDevice)
119+
: (CUdeviceptr)DstDevice;
120+
121+
if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
122+
// Create a new node
123+
CUgraphNode GraphNode;
124+
CUDA_MEMSET_NODE_PARAMS NodeParams = {};
125+
NodeParams.dst = DstPtr;
126+
NodeParams.elementSize = PatternSize;
127+
NodeParams.height = N;
128+
NodeParams.pitch = PatternSize;
129+
NodeParams.value = Value;
130+
NodeParams.width = 1;
131+
132+
Result = UR_CHECK_ERROR(cuGraphAddMemsetNode(
133+
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
134+
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
135+
136+
// Get sync point and register the cuNode with it.
137+
*SyncPoint =
138+
CommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
139+
140+
} else {
141+
// CUDA has no memset functions that allow setting values more than 4
142+
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
143+
// fill, which can be more than 4 bytes. We must break up the pattern
144+
// into 4 byte values, and set the buffer using multiple strided calls.
145+
// This means that one cuGraphAddMemsetNode call is made for every 4 bytes
146+
// in the pattern.
147+
148+
size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
149+
150+
// we walk up the pattern in 4-byte steps, and call cuMemset for each
151+
// 4-byte chunk of the pattern.
152+
for (auto Step = 0u; Step < NumberOfSteps; ++Step) {
153+
// take 4 bytes of the pattern
154+
auto Value = *(static_cast<const uint32_t *>(Pattern) + Step);
155+
156+
// offset the pointer to the part of the buffer we want to write to
157+
auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t));
158+
159+
// Create a new node
160+
CUgraphNode GraphNode;
161+
// Update NodeParam
162+
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
163+
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
164+
NodeParamsStep.elementSize = 4;
165+
NodeParamsStep.height = N;
166+
NodeParamsStep.pitch = PatternSize;
167+
NodeParamsStep.value = Value;
168+
NodeParamsStep.width = 1;
169+
170+
Result = UR_CHECK_ERROR(cuGraphAddMemsetNode(
171+
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
172+
DepsList.size(), &NodeParamsStep,
173+
CommandBuffer->Device->getContext()));
174+
175+
// Get sync point and register the cuNode with it.
176+
*SyncPoint = CommandBuffer->AddSyncPoint(
177+
std::make_shared<CUgraphNode>(GraphNode));
178+
}
179+
}
180+
} catch (ur_result_t Err) {
181+
Result = Err;
182+
}
183+
return Result;
184+
}
185+
102186
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
103187
ur_context_handle_t hContext, ur_device_handle_t hDevice,
104188
const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
@@ -602,20 +686,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
602686
uint32_t numSyncPointsInWaitList,
603687
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
604688
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
605-
(void)hCommandBuffer;
606-
(void)hBuffer;
607-
(void)pPattern;
608-
(void)patternSize;
609-
(void)offset;
610-
(void)size;
611-
612-
(void)numSyncPointsInWaitList;
613-
(void)pSyncPointWaitList;
614-
(void)pSyncPoint;
615-
616-
detail::ur::die("Experimental Command-buffer feature is not "
617-
"implemented for CUDA adapter.");
618-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
689+
auto ArgsAreMultiplesOfPatternSize =
690+
(offset % patternSize == 0) || (size % patternSize == 0);
691+
692+
auto PatternIsValid = (pPattern != nullptr);
693+
694+
auto PatternSizeIsValid = ((patternSize & (patternSize - 1)) == 0) &&
695+
(patternSize > 0); // is a positive power of two
696+
UR_ASSERT(ArgsAreMultiplesOfPatternSize && PatternIsValid &&
697+
PatternSizeIsValid,
698+
UR_RESULT_ERROR_INVALID_SIZE);
699+
700+
auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;
701+
702+
return enqueueCommandBufferFillHelper(
703+
hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
704+
size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
619705
}
620706

621707
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
@@ -624,19 +710,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
624710
uint32_t numSyncPointsInWaitList,
625711
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
626712
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
627-
(void)hCommandBuffer;
628-
(void)pPtr;
629-
(void)pPattern;
630-
(void)patternSize;
631-
(void)size;
632-
633-
(void)numSyncPointsInWaitList;
634-
(void)pSyncPointWaitList;
635-
(void)pSyncPoint;
636-
637-
detail::ur::die("Experimental Command-buffer feature is not "
638-
"implemented for CUDA adapter.");
639-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
713+
714+
auto PatternIsValid = (pPattern != nullptr);
715+
716+
auto PatternSizeIsValid = ((patternSize & (patternSize - 1)) == 0) &&
717+
(patternSize > 0); // is a positive power of two
718+
719+
UR_ASSERT(PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
720+
return enqueueCommandBufferFillHelper(
721+
hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
722+
numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
640723
}
641724

642725
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(

0 commit comments

Comments
 (0)