Skip to content

Commit 802a3ce

Browse files
committed
EnqueueKernel
1 parent cd0bec9 commit 802a3ce

File tree

1 file changed

+75
-40
lines changed

1 file changed

+75
-40
lines changed

source/adapters/level_zero/command_buffer.cpp

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -718,31 +718,28 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t CommandBuffer) {
718718
return UR_RESULT_SUCCESS;
719719
}
720720

721-
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
722-
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
723-
uint32_t WorkDim, const size_t *GlobalWorkOffset,
724-
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
725-
uint32_t NumSyncPointsInWaitList,
726-
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
727-
ur_exp_command_buffer_sync_point_t *RetSyncPoint,
728-
ur_exp_command_buffer_command_handle_t *Command) {
729-
UR_ASSERT(Kernel->Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
730-
// Lock automatically releases when this goes out of scope.
731-
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
732-
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
733-
734-
if (GlobalWorkOffset != NULL) {
735-
if (!CommandBuffer->Context->getPlatform()
736-
->ZeDriverGlobalOffsetExtensionFound) {
737-
logger::debug("No global offset extension found on this driver");
738-
return UR_RESULT_ERROR_INVALID_VALUE;
739-
}
721+
static ur_result_t
722+
setKernelGlobalOffset(ur_exp_command_buffer_handle_t CommandBuffer,
723+
ur_kernel_handle_t Kernel,
724+
const size_t *GlobalWorkOffset) {
740725

741-
ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
742-
(Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
743-
GlobalWorkOffset[2]));
726+
if (!CommandBuffer->Context->getPlatform()
727+
->ZeDriverGlobalOffsetExtensionFound) {
728+
logger::debug("No global offset extension found on this driver");
729+
return UR_RESULT_ERROR_INVALID_VALUE;
744730
}
745731

732+
ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
733+
(Kernel->ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
734+
GlobalWorkOffset[2]));
735+
736+
return UR_RESULT_SUCCESS;
737+
}
738+
739+
static ur_result_t
740+
setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
741+
ur_kernel_handle_t Kernel) {
742+
746743
// If there are any pending arguments set them now.
747744
for (auto &Arg : Kernel->PendingArguments) {
748745
// The ArgValue may be a NULL pointer in which case a NULL value is used for
@@ -757,25 +754,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
757754
}
758755
Kernel->PendingArguments.clear();
759756

760-
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
761-
uint32_t WG[3];
762-
763-
UR_CALL(calculateKernelWorkDimensions(Kernel, CommandBuffer->Device,
764-
ZeThreadGroupDimensions, WG, WorkDim,
765-
GlobalWorkSize, LocalWorkSize));
766-
767-
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
757+
return UR_RESULT_SUCCESS;
758+
}
768759

769-
CommandBuffer->KernelsList.push_back(Kernel);
770-
// Increment the reference count of the Kernel and indicate that the Kernel
771-
// is in use. Once the event has been signaled, the code in
772-
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
773-
// reference count on the kernel, using the kernel saved in CommandData.
774-
UR_CALL(urKernelRetain(Kernel));
760+
static ur_result_t
761+
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
762+
ur_kernel_handle_t Kernel, uint32_t WorkDim,
763+
const size_t *LocalWorkSize,
764+
ur_exp_command_buffer_command_handle_t& Command) {
775765

776766
// If command-buffer is updatable then get command id which is going to be
777767
// used if command is updated in the future. This
778-
// zeCommandListGetNextCommandIdExp can be called only if command is
768+
// zeCommandListGetNextCommandIdExp can be called only if the command is
779769
// updatable.
780770
uint64_t CommandId = 0;
781771
if (CommandBuffer->IsUpdatable) {
@@ -794,15 +784,60 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
794784
DEBUG_LOG(CommandId);
795785
}
796786
try {
797-
if (Command)
798-
*Command = new ur_exp_command_buffer_command_handle_t_(
799-
CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr, Kernel);
787+
Command = new ur_exp_command_buffer_command_handle_t_(
788+
CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr, Kernel);
800789
} catch (const std::bad_alloc &) {
801790
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
802791
} catch (...) {
803792
return UR_RESULT_ERROR_UNKNOWN;
804793
}
805794

795+
return UR_RESULT_SUCCESS;
796+
}
797+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
798+
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
799+
uint32_t WorkDim, const size_t *GlobalWorkOffset,
800+
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
801+
uint32_t NumSyncPointsInWaitList,
802+
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
803+
ur_exp_command_buffer_sync_point_t *RetSyncPoint,
804+
ur_exp_command_buffer_command_handle_t *Command) {
805+
UR_ASSERT(Kernel->Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
806+
807+
// Lock automatically releases when this goes out of scope.
808+
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
809+
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
810+
811+
if (GlobalWorkOffset != NULL) {
812+
UR_CALL(setKernelGlobalOffset(CommandBuffer, Kernel, GlobalWorkOffset));
813+
}
814+
815+
// If there are any pending arguments set them now.
816+
if (!Kernel->PendingArguments.empty()) {
817+
UR_CALL(setKernelPendingArguments(CommandBuffer, Kernel));
818+
}
819+
820+
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
821+
uint32_t WG[3];
822+
UR_CALL(calculateKernelWorkDimensions(Kernel, CommandBuffer->Device,
823+
ZeThreadGroupDimensions, WG, WorkDim,
824+
GlobalWorkSize, LocalWorkSize));
825+
826+
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
827+
828+
CommandBuffer->KernelsList.push_back(Kernel);
829+
830+
// Increment the reference count of the Kernel and indicate that the Kernel
831+
// is in use. Once the event has been signaled, the code in
832+
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
833+
// reference count on the kernel, using the kernel saved in CommandData.
834+
UR_CALL(urKernelRetain(Kernel));
835+
836+
if (Command && CommandBuffer->IsUpdatable) {
837+
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize,
838+
*Command));
839+
}
840+
806841
if (CommandBuffer->IsInOrderCmdList) {
807842
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
808843
(CommandBuffer->ZeComputeCommandList, Kernel->ZeKernel,

0 commit comments

Comments
 (0)