@@ -99,6 +99,90 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
99
99
Params.Depth = 1 ;
100
100
}
101
101
102
+ // Helper function for enqueuing memory fills
103
+ static ur_result_t enqueueCommandBufferFillHelper (
104
+ ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
105
+ const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
106
+ size_t Size, uint32_t NumSyncPointsInWaitList,
107
+ const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
108
+ ur_exp_command_buffer_sync_point_t *SyncPoint) {
109
+ ur_result_t Result;
110
+ std::vector<CUgraphNode> DepsList;
111
+ UR_CALL (getNodesFromSyncPoints (CommandBuffer, NumSyncPointsInWaitList,
112
+ SyncPointWaitList, DepsList));
113
+
114
+ try {
115
+ size_t N = Size / PatternSize;
116
+ auto Value = *static_cast <const uint32_t *>(Pattern);
117
+ auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
118
+ ? *static_cast <CUdeviceptr *>(DstDevice)
119
+ : (CUdeviceptr)DstDevice;
120
+
121
+ if ((PatternSize == 1 ) || (PatternSize == 2 ) || (PatternSize == 4 )) {
122
+ // Create a new node
123
+ CUgraphNode GraphNode;
124
+ CUDA_MEMSET_NODE_PARAMS NodeParams = {};
125
+ NodeParams.dst = DstPtr;
126
+ NodeParams.elementSize = PatternSize;
127
+ NodeParams.height = N;
128
+ NodeParams.pitch = PatternSize;
129
+ NodeParams.value = Value;
130
+ NodeParams.width = 1 ;
131
+
132
+ Result = UR_CHECK_ERROR (cuGraphAddMemsetNode (
133
+ &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
134
+ DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
135
+
136
+ // Get sync point and register the cuNode with it.
137
+ *SyncPoint =
138
+ CommandBuffer->AddSyncPoint (std::make_shared<CUgraphNode>(GraphNode));
139
+
140
+ } else {
141
+ // CUDA has no memset functions that allow setting values more than 4
142
+ // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
143
+ // fill, which can be more than 4 bytes. We must break up the pattern
144
+ // into 4 byte values, and set the buffer using multiple strided calls.
145
+ // This means that one cuGraphAddMemsetNode call is made for every 4 bytes
146
+ // in the pattern.
147
+
148
+ size_t NumberOfSteps = PatternSize / sizeof (uint32_t );
149
+
150
+ // we walk up the pattern in 4-byte steps, and call cuMemset for each
151
+ // 4-byte chunk of the pattern.
152
+ for (auto Step = 0u ; Step < NumberOfSteps; ++Step) {
153
+ // take 4 bytes of the pattern
154
+ auto Value = *(static_cast <const uint32_t *>(Pattern) + Step);
155
+
156
+ // offset the pointer to the part of the buffer we want to write to
157
+ auto OffsetPtr = DstPtr + (Step * sizeof (uint32_t ));
158
+
159
+ // Create a new node
160
+ CUgraphNode GraphNode;
161
+ // Update NodeParam
162
+ CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
163
+ NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
164
+ NodeParamsStep.elementSize = 4 ;
165
+ NodeParamsStep.height = N;
166
+ NodeParamsStep.pitch = PatternSize;
167
+ NodeParamsStep.value = Value;
168
+ NodeParamsStep.width = 1 ;
169
+
170
+ Result = UR_CHECK_ERROR (cuGraphAddMemsetNode (
171
+ &GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
172
+ DepsList.size (), &NodeParamsStep,
173
+ CommandBuffer->Device ->getContext ()));
174
+
175
+ // Get sync point and register the cuNode with it.
176
+ *SyncPoint = CommandBuffer->AddSyncPoint (
177
+ std::make_shared<CUgraphNode>(GraphNode));
178
+ }
179
+ }
180
+ } catch (ur_result_t Err) {
181
+ Result = Err;
182
+ }
183
+ return Result;
184
+ }
185
+
102
186
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp (
103
187
ur_context_handle_t hContext, ur_device_handle_t hDevice,
104
188
const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
@@ -602,20 +686,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
602
686
uint32_t numSyncPointsInWaitList,
603
687
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
604
688
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
605
- (void )hCommandBuffer;
606
- (void )hBuffer;
607
- (void )pPattern;
608
- (void )patternSize;
609
- (void )offset;
610
- (void )size;
611
-
612
- (void )numSyncPointsInWaitList;
613
- (void )pSyncPointWaitList;
614
- (void )pSyncPoint;
615
-
616
- detail::ur::die (" Experimental Command-buffer feature is not "
617
- " implemented for CUDA adapter." );
618
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
689
+ auto ArgsAreMultiplesOfPatternSize =
690
+ (offset % patternSize == 0 ) || (size % patternSize == 0 );
691
+
692
+ auto PatternIsValid = (pPattern != nullptr );
693
+
694
+ auto PatternSizeIsValid = ((patternSize & (patternSize - 1 )) == 0 ) &&
695
+ (patternSize > 0 ); // is a positive power of two
696
+ UR_ASSERT (ArgsAreMultiplesOfPatternSize && PatternIsValid &&
697
+ PatternSizeIsValid,
698
+ UR_RESULT_ERROR_INVALID_SIZE);
699
+
700
+ auto DstDevice = std::get<BufferMem>(hBuffer->Mem ).get () + offset;
701
+
702
+ return enqueueCommandBufferFillHelper (
703
+ hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
704
+ size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
619
705
}
620
706
621
707
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp (
@@ -624,19 +710,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
624
710
uint32_t numSyncPointsInWaitList,
625
711
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
626
712
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
627
- (void )hCommandBuffer;
628
- (void )pPtr;
629
- (void )pPattern;
630
- (void )patternSize;
631
- (void )size;
632
-
633
- (void )numSyncPointsInWaitList;
634
- (void )pSyncPointWaitList;
635
- (void )pSyncPoint;
636
-
637
- detail::ur::die (" Experimental Command-buffer feature is not "
638
- " implemented for CUDA adapter." );
639
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
713
+
714
+ auto PatternIsValid = (pPattern != nullptr );
715
+
716
+ auto PatternSizeIsValid = ((patternSize & (patternSize - 1 )) == 0 ) &&
717
+ (patternSize > 0 ); // is a positive power of two
718
+
719
+ UR_ASSERT (PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
720
+ return enqueueCommandBufferFillHelper (
721
+ hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
722
+ numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
640
723
}
641
724
642
725
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp (
0 commit comments