@@ -718,31 +718,28 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
718
718
return UR_RESULT_SUCCESS;
719
719
}
720
720
721
- UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp (
722
- ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
723
- uint32_t WorkDim, const size_t *GlobalWorkOffset,
724
- const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
725
- uint32_t NumSyncPointsInWaitList,
726
- const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
727
- ur_exp_command_buffer_sync_point_t *RetSyncPoint,
728
- ur_exp_command_buffer_command_handle_t *Command) {
729
- UR_ASSERT (Kernel->Program , UR_RESULT_ERROR_INVALID_NULL_POINTER);
730
- // Lock automatically releases when this goes out of scope.
731
- std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
732
- Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
733
-
734
- if (GlobalWorkOffset != NULL ) {
735
- if (!CommandBuffer->Context ->getPlatform ()
736
- ->ZeDriverGlobalOffsetExtensionFound ) {
737
- logger::debug (" No global offset extension found on this driver" );
738
- return UR_RESULT_ERROR_INVALID_VALUE;
739
- }
721
// Review annotation: the '+' lines below add the static helper
// setKernelGlobalOffset; the '-' lines are the equivalent statements being
// removed from the old inline code.
//
// Purpose: apply a global work offset to Kernel->ZeKernel.
// Parameters:
//   CommandBuffer    - used only to reach Context->getPlatform() and read
//                      ZeDriverGlobalOffsetExtensionFound.
//   Kernel           - its ZeKernel member is passed to
//                      zeKernelSetGlobalOffsetExp.
//   GlobalWorkOffset - elements [0], [1] and [2] are read unconditionally;
//                      the helper does not null-check it. The caller visible
//                      later in this diff only calls it when
//                      GlobalWorkOffset != NULL.
// Returns: UR_RESULT_ERROR_INVALID_VALUE when the platform flag is false,
//          UR_RESULT_SUCCESS at the end otherwise.
// NOTE(review): ZE2UR_CALL is defined outside this view; presumably it returns
// early with a translated error if the Level Zero call fails - confirm.
+ static ur_result_t
722
+ setKernelGlobalOffset (ur_exp_command_buffer_handle_t CommandBuffer,
723
+ ur_kernel_handle_t Kernel,
724
+ const size_t *GlobalWorkOffset) {
740
725
741
- ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
742
- (Kernel->ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
743
- GlobalWorkOffset[2 ]));
726
// Reject the request when the platform reports no global-offset extension.
+ if (!CommandBuffer->Context ->getPlatform ()
727
+ ->ZeDriverGlobalOffsetExtensionFound ) {
728
+ logger::debug (" No global offset extension found on this driver" );
729
+ return UR_RESULT_ERROR_INVALID_VALUE;
744
730
}
745
731
732
// Pass the three offset components to the Level Zero kernel handle.
+ ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
733
+ (Kernel->ZeKernel , GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
734
+ GlobalWorkOffset[2 ]));
735
+
736
+ return UR_RESULT_SUCCESS;
737
+ }
738
+
739
+ static ur_result_t
740
+ setKernelPendingArguments (ur_exp_command_buffer_handle_t CommandBuffer,
741
+ ur_kernel_handle_t Kernel) {
742
+
746
743
// If there are any pending arguments set them now.
747
744
for (auto &Arg : Kernel->PendingArguments ) {
748
745
// The ArgValue may be a NULL pointer in which case a NULL value is used for
@@ -757,25 +754,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
757
754
}
758
755
Kernel->PendingArguments .clear ();
759
756
760
- ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
761
- uint32_t WG[3 ];
762
-
763
- UR_CALL (calculateKernelWorkDimensions (Kernel, CommandBuffer->Device ,
764
- ZeThreadGroupDimensions, WG, WorkDim,
765
- GlobalWorkSize, LocalWorkSize));
766
-
767
- ZE2UR_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
757
+ return UR_RESULT_SUCCESS;
758
+ }
768
759
769
- CommandBuffer->KernelsList .push_back (Kernel);
770
- // Increment the reference count of the Kernel and indicate that the Kernel
771
- // is in use. Once the event has been signaled, the code in
772
- // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
773
- // reference count on the kernel, using the kernel saved in CommandData.
774
- UR_CALL (urKernelRetain (Kernel));
760
+ static ur_result_t
761
+ createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
762
+ ur_kernel_handle_t Kernel, uint32_t WorkDim,
763
+ const size_t *LocalWorkSize,
764
+ ur_exp_command_buffer_command_handle_t & Command) {
775
765
776
766
// If command-buffer is updatable then get command id which is going to be
777
767
// used if command is updated in the future. This
778
- // zeCommandListGetNextCommandIdExp can be called only if command is
768
+ // zeCommandListGetNextCommandIdExp can be called only if the command is
779
769
// updatable.
780
770
uint64_t CommandId = 0 ;
781
771
if (CommandBuffer->IsUpdatable ) {
@@ -794,15 +784,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
794
784
DEBUG_LOG (CommandId);
795
785
}
796
786
try {
797
- if (Command)
798
- *Command = new ur_exp_command_buffer_command_handle_t_ (
799
- CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr , Kernel);
787
+ Command = new ur_exp_command_buffer_command_handle_t_ (
788
+ CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr , Kernel);
800
789
} catch (const std::bad_alloc &) {
801
790
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
802
791
} catch (...) {
803
792
return UR_RESULT_ERROR_UNKNOWN;
804
793
}
805
794
795
+ return UR_RESULT_SUCCESS;
796
+ }
797
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp (
798
+ ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
799
+ uint32_t WorkDim, const size_t *GlobalWorkOffset,
800
+ const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
801
+ uint32_t NumSyncPointsInWaitList,
802
+ const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
803
+ ur_exp_command_buffer_sync_point_t *RetSyncPoint,
804
+ ur_exp_command_buffer_command_handle_t *Command) {
805
+ UR_ASSERT (Kernel->Program , UR_RESULT_ERROR_INVALID_NULL_POINTER);
806
+
807
+ // Lock automatically releases when this goes out of scope.
808
+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
809
+ Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
810
+
811
+ if (GlobalWorkOffset != NULL ) {
812
+ UR_CALL (setKernelGlobalOffset (CommandBuffer, Kernel, GlobalWorkOffset));
813
+ }
814
+
815
+ // If there are any pending arguments set them now.
816
+ if (!Kernel->PendingArguments .empty ()) {
817
+ UR_CALL (setKernelPendingArguments (CommandBuffer, Kernel));
818
+ }
819
+
820
+ ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
821
+ uint32_t WG[3 ];
822
+ UR_CALL (calculateKernelWorkDimensions (Kernel, CommandBuffer->Device ,
823
+ ZeThreadGroupDimensions, WG, WorkDim,
824
+ GlobalWorkSize, LocalWorkSize));
825
+
826
+ ZE2UR_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
827
+
828
+ CommandBuffer->KernelsList .push_back (Kernel);
829
+
830
+ // Increment the reference count of the Kernel and indicate that the Kernel
831
+ // is in use. Once the event has been signaled, the code in
832
+ // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
833
+ // reference count on the kernel, using the kernel saved in CommandData.
834
+ UR_CALL (urKernelRetain (Kernel));
835
+
836
+ if (Command && CommandBuffer->IsUpdatable ) {
837
+ UR_CALL (createCommandHandle (CommandBuffer, Kernel, WorkDim, LocalWorkSize,
838
+ *Command));
839
+ }
840
+
806
841
if (CommandBuffer->IsInOrderCmdList ) {
807
842
ZE2UR_CALL (zeCommandListAppendLaunchKernel,
808
843
(CommandBuffer->ZeComputeCommandList , Kernel->ZeKernel ,
0 commit comments