Skip to content

Commit a1a317a

Browse files
committed
Address review comments
1 parent a71c1d5 commit a1a317a

File tree

2 files changed

+93
-96
lines changed

2 files changed

+93
-96
lines changed

source/adapters/level_zero/command_buffer.cpp

Lines changed: 91 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ namespace {
2323
// Gets a C pointer from a vector. If the vector is empty returns nullptr
2424
// instead. This is different from the behaviour of the data() member function
2525
// of the vector class which might not return nullptr when the vector is empty.
26-
template <typename T> static T *getPointerFromVector(std::vector<T> &V) {
26+
template <typename T> T *getPointerFromVector(std::vector<T> &V) {
2727
return V.size() == 0 ? nullptr : V.data();
2828
}
2929

@@ -33,10 +33,10 @@ template <typename T> static T *getPointerFromVector(std::vector<T> &V) {
3333
* @param[in] VersionMajor Major version number to compare to.
3434
* @param[in] VersionMinor Minor version number to compare to.
3535
* @param[in] VersionBuild Build version number to compare to.
36-
* @return true is the version of the driver is higher than or equal to the
36+
* @return true if the version of the driver is higher than or equal to the
3737
* compared version.
3838
*/
39-
bool IsDriverVersionNewerOrSimilar(ur_context_handle_t Context,
39+
bool isDriverVersionNewerOrSimilar(ur_context_handle_t Context,
4040
uint32_t VersionMajor, uint32_t VersionMinor,
4141
uint32_t VersionBuild) {
4242
ZeStruct<ze_driver_properties_t> ZeDriverProperties;
@@ -64,15 +64,18 @@ bool IsDriverVersionNewerOrSimilar(ur_context_handle_t Context,
6464
* @return UR_RESULT_SUCCESS or an error code on failure
6565
*/
6666
ur_result_t
67-
PreferCopyEngineForFill(ur_exp_command_buffer_handle_t CommandBuffer,
67+
preferCopyEngineForFill(ur_exp_command_buffer_handle_t CommandBuffer,
6868
size_t PatternSize, bool &PreferCopyEngine) {
69-
const char *UrRet = std::getenv("UR_L0_USE_COPY_ENGINE_FOR_FILL");
70-
const char *PiRet =
71-
std::getenv("SYCL_PI_LEVEL_ZERO_USE_COPY_ENGINE_FOR_FILL");
69+
assert(PatternSize > 0);
70+
71+
PreferCopyEngine = false;
72+
if (!CommandBuffer->UseCopyEngine()) {
73+
return UR_RESULT_SUCCESS;
74+
}
7275

73-
// If the copy engine available and PatternSize is valid, the command is
74-
// enqueued in the ZeCopyCommandList, otherwise enqueue it in the compute
75-
// command list.
76+
// If the copy engine is available, and it supports this pattern size, the
77+
// command should be enqueued in the copy command list, otherwise enqueue it
78+
// in the compute command list.
7679
PreferCopyEngine =
7780
PatternSize <=
7881
CommandBuffer->Device
@@ -89,6 +92,10 @@ PreferCopyEngineForFill(ur_exp_command_buffer_handle_t CommandBuffer,
8992
UR_RESULT_ERROR_INVALID_VALUE);
9093
}
9194

95+
const char *UrRet = std::getenv("UR_L0_USE_COPY_ENGINE_FOR_FILL");
96+
const char *PiRet =
97+
std::getenv("SYCL_PI_LEVEL_ZERO_USE_COPY_ENGINE_FOR_FILL");
98+
9299
PreferCopyEngine =
93100
PreferCopyEngine &&
94101
(UrRet ? std::stoi(UrRet) : (PiRet ? std::stoi(PiRet) : 0));
@@ -253,49 +260,52 @@ ur_result_t getEventsFromSyncPoints(
253260
}
254261

255262
/**
256-
* If needed, creates a sync point for a given command.
263+
* If needed, creates a sync point for a given command and returns the L0
264+
* events associated with the sync point.
257265
* This operation is skipped if the command buffer is in order.
258266
* @param[in] CommandType The type of the command.
259267
* @param[in] CommandBuffer The CommandBuffer where the command is appended.
260268
* @param[in] NumSyncPointsInWaitList Number of sync points that are
261269
* dependencies for the command.
262270
* @param[in] SyncPointWaitList List of sync point that are dependencies for the
263271
* command.
264-
* @param[out][optional] RetSyncPoint The new sync point.
265-
* @param[in] host_visible Whether the event associated with the sync point
272+
* @param[in] HostVisible Whether the event associated with the sync point
266273
* should be host visible.
274+
* @param[out][optional] RetSyncPoint The new sync point.
267275
* @param[out] ZeEventList A list of L0 events that are dependencies for this
268276
* sync point.
269277
* @param[out] ZeLaunchEvent The L0 event associated with this sync point.
270278
* @return UR_RESULT_SUCCESS or an error code on failure
271279
*/
272-
ur_result_t createSyncPointIfNeeded(
280+
ur_result_t createSyncPointAndGetZeEvents(
273281
ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
274282
uint32_t NumSyncPointsInWaitList,
275283
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
276-
ur_exp_command_buffer_sync_point_t *RetSyncPoint, bool host_visible,
284+
bool HostVisible, ur_exp_command_buffer_sync_point_t *RetSyncPoint,
277285
std::vector<ze_event_handle_t> &ZeEventList,
278286
ze_event_handle_t &ZeLaunchEvent) {
279287

280288
ZeLaunchEvent = nullptr;
281-
if (!CommandBuffer->IsInOrderCmdList) {
282-
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
283-
SyncPointWaitList, ZeEventList));
284-
ur_event_handle_t LaunchEvent;
285-
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, host_visible,
286-
&LaunchEvent, false,
287-
!CommandBuffer->IsProfilingEnabled));
288-
LaunchEvent->CommandType = CommandType;
289-
ZeLaunchEvent = LaunchEvent->ZeEvent;
290-
291-
// Get sync point and register the event with it.
292-
ur_exp_command_buffer_sync_point_t SyncPoint =
293-
CommandBuffer->GetNextSyncPoint();
294-
CommandBuffer->RegisterSyncPoint(SyncPoint, LaunchEvent);
295-
296-
if (RetSyncPoint) {
297-
*RetSyncPoint = SyncPoint;
298-
}
289+
290+
if (CommandBuffer->IsInOrderCmdList) {
291+
return UR_RESULT_SUCCESS;
292+
}
293+
294+
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
295+
SyncPointWaitList, ZeEventList));
296+
ur_event_handle_t LaunchEvent;
297+
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, false, HostVisible,
298+
&LaunchEvent, false, !CommandBuffer->IsProfilingEnabled));
299+
LaunchEvent->CommandType = CommandType;
300+
ZeLaunchEvent = LaunchEvent->ZeEvent;
301+
302+
// Get sync point and register the event with it.
303+
ur_exp_command_buffer_sync_point_t SyncPoint =
304+
CommandBuffer->GetNextSyncPoint();
305+
CommandBuffer->RegisterSyncPoint(SyncPoint, LaunchEvent);
306+
307+
if (RetSyncPoint) {
308+
*RetSyncPoint = SyncPoint;
299309
}
300310

301311
return UR_RESULT_SUCCESS;
@@ -313,12 +323,12 @@ ur_result_t enqueueCommandBufferMemCopyHelper(
313323

314324
std::vector<ze_event_handle_t> ZeEventList;
315325
ze_event_handle_t ZeLaunchEvent = nullptr;
316-
UR_CALL(createSyncPointIfNeeded(
326+
UR_CALL(createSyncPointAndGetZeEvents(
317327
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
318-
RetSyncPoint, false, ZeEventList, ZeLaunchEvent));
328+
false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
319329

320-
ze_command_list_handle_t ZeCommandList;
321-
UR_CALL(CommandBuffer->chooseCommandList(PreferCopyEngine, &ZeCommandList));
330+
ze_command_list_handle_t ZeCommandList =
331+
CommandBuffer->chooseCommandList(PreferCopyEngine);
322332

323333
logger::debug("calling zeCommandListAppendMemoryCopy()");
324334
ZE2UR_CALL(zeCommandListAppendMemoryCopy,
@@ -372,12 +382,12 @@ ur_result_t enqueueCommandBufferMemCopyRectHelper(
372382

373383
std::vector<ze_event_handle_t> ZeEventList;
374384
ze_event_handle_t ZeLaunchEvent = nullptr;
375-
UR_CALL(createSyncPointIfNeeded(
385+
UR_CALL(createSyncPointAndGetZeEvents(
376386
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
377-
RetSyncPoint, false, ZeEventList, ZeLaunchEvent));
387+
false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
378388

379-
ze_command_list_handle_t ZeCommandList;
380-
UR_CALL(CommandBuffer->chooseCommandList(PreferCopyEngine, &ZeCommandList));
389+
ze_command_list_handle_t ZeCommandList =
390+
CommandBuffer->chooseCommandList(PreferCopyEngine);
381391

382392
logger::debug("calling zeCommandListAppendMemoryCopyRegion()");
383393
ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
@@ -401,16 +411,16 @@ ur_result_t enqueueCommandBufferFillHelper(
401411

402412
std::vector<ze_event_handle_t> ZeEventList;
403413
ze_event_handle_t ZeLaunchEvent = nullptr;
404-
UR_CALL(createSyncPointIfNeeded(
414+
UR_CALL(createSyncPointAndGetZeEvents(
405415
CommandType, CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList,
406-
RetSyncPoint, true, ZeEventList, ZeLaunchEvent));
416+
true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
407417

408418
bool PreferCopyEngine;
409419
UR_CALL(
410-
PreferCopyEngineForFill(CommandBuffer, PatternSize, PreferCopyEngine));
420+
preferCopyEngineForFill(CommandBuffer, PatternSize, PreferCopyEngine));
411421

412-
ze_command_list_handle_t ZeCommandList;
413-
UR_CALL(CommandBuffer->chooseCommandList(PreferCopyEngine, &ZeCommandList));
422+
ze_command_list_handle_t ZeCommandList =
423+
CommandBuffer->chooseCommandList(PreferCopyEngine);
414424

415425
logger::debug("calling zeCommandListAppendMemoryFill()");
416426
ZE2UR_CALL(zeCommandListAppendMemoryFill,
@@ -549,15 +559,14 @@ void ur_exp_command_buffer_handle_t_::RegisterSyncPoint(
549559
ZeEventsList.push_back(Event->ZeEvent);
550560
}
551561

552-
ur_result_t ur_exp_command_buffer_handle_t_::chooseCommandList(
553-
bool PreferCopyEngine, ze_command_list_handle_t *ZeCommandList) {
562+
ze_command_list_handle_t
563+
ur_exp_command_buffer_handle_t_::chooseCommandList(bool PreferCopyEngine) {
554564
if (PreferCopyEngine && this->UseCopyEngine() && !this->IsInOrderCmdList) {
555565
// We indicate that ZeCopyCommandList contains commands to be submitted.
556566
this->MCopyCommandListEmpty = false;
557-
*ZeCommandList = this->ZeCopyCommandList;
567+
return this->ZeCopyCommandList;
558568
}
559-
*ZeCommandList = this->ZeComputeCommandList;
560-
return UR_RESULT_SUCCESS;
569+
return this->ZeComputeCommandList;
561570
}
562571

563572
ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
@@ -584,19 +593,19 @@ namespace {
584593
* @param[in] Context The Context associated with the command-list
585594
* @param[in] Device The Device associated with the command-list
586595
* @param[in] IsInOrder Whether the command-list should be in-order.
587-
* @param[in] isUpdatable Whether the command-list should be mutable.
588-
* @param[in] isCopy Whether to use copy-engine for the the new command-list.
596+
* @param[in] IsUpdatable Whether the command-list should be mutable.
597+
* @param[in] IsCopy Whether to use copy-engine for the new command-list.
589598
* @param[out] CommandList The L0 command-list created by this function.
590599
* @return UR_RESULT_SUCCESS or an error code on failure
591600
*/
592601
ur_result_t createMainCommandList(ur_context_handle_t Context,
593602
ur_device_handle_t Device, bool IsInOrder,
594-
bool isUpdatable, bool isCopy,
603+
bool IsUpdatable, bool IsCopy,
595604
ze_command_list_handle_t &CommandList) {
596605

597-
auto type = isCopy ? ur_device_handle_t_::queue_group_info_t::type::MainCopy
606+
auto Type = IsCopy ? ur_device_handle_t_::queue_group_info_t::type::MainCopy
598607
: ur_device_handle_t_::queue_group_info_t::type::Compute;
599-
uint32_t QueueGroupOrdinal = Device->QueueGroup[type].ZeOrdinal;
608+
uint32_t QueueGroupOrdinal = Device->QueueGroup[Type].ZeOrdinal;
600609

601610
ZeStruct<ze_command_list_desc_t> ZeCommandListDesc;
602611
ZeCommandListDesc.commandQueueGroupOrdinal = QueueGroupOrdinal;
@@ -610,7 +619,10 @@ ur_result_t createMainCommandList(ur_context_handle_t Context,
610619
DEBUG_LOG(ZeCommandListDesc.flags);
611620

612621
ZeStruct<ze_mutable_command_list_exp_desc_t> ZeMutableCommandListDesc;
613-
if (isUpdatable) {
622+
if (IsUpdatable) {
623+
auto Platform = Context->getPlatform();
624+
UR_ASSERT(Platform->ZeMutableCmdListExt.Supported,
625+
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
614626
ZeMutableCommandListDesc.flags = 0;
615627
ZeCommandListDesc.pNext = &ZeMutableCommandListDesc;
616628
}
@@ -621,19 +633,6 @@ ur_result_t createMainCommandList(ur_context_handle_t Context,
621633
return UR_RESULT_SUCCESS;
622634
}
623635

624-
/**
625-
* Appends a barrier to CommandList that waits for all the Events in EventList
626-
* @param[in] CommandList The command-list where the barrier should be appended.
627-
* @param[in] ZeEventList A list of events to wait for.
628-
* @return UR_RESULT_SUCCESS or an error code on failure
629-
*/
630-
ur_result_t waitForEvents(ze_command_list_handle_t CommandList,
631-
std::vector<ze_event_handle_t> &ZeEventList) {
632-
ZE2UR_CALL(zeCommandListAppendBarrier,
633-
(CommandList, nullptr, ZeEventList.size(), ZeEventList.data()));
634-
return UR_RESULT_SUCCESS;
635-
}
636-
637636
/**
638637
* Checks whether the command buffer can be constructed using in order
639638
* command-lists.
@@ -644,7 +643,7 @@ ur_result_t waitForEvents(ze_command_list_handle_t CommandList,
644643
bool canBeInOrder(ur_context_handle_t Context,
645644
const ur_exp_command_buffer_desc_t *CommandBufferDesc) {
646645
// In-order command-lists are not available in old driver versions.
647-
bool CompatibleDriver = IsDriverVersionNewerOrSimilar(Context, 1, 3, 28454);
646+
bool CompatibleDriver = isDriverVersionNewerOrSimilar(Context, 1, 3, 28454);
648647
return CompatibleDriver
649648
? (CommandBufferDesc ? CommandBufferDesc->isInOrder : false)
650649
: false;
@@ -657,7 +656,7 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
657656
ur_exp_command_buffer_handle_t *CommandBuffer) {
658657

659658
bool IsInOrder = canBeInOrder(Context, CommandBufferDesc);
660-
bool enableProfiling =
659+
bool EnableProfiling =
661660
CommandBufferDesc && CommandBufferDesc->enableProfiling;
662661
bool IsUpdatable = CommandBufferDesc && CommandBufferDesc->isUpdatable;
663662

@@ -666,18 +665,20 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
666665
ur_event_handle_t AllResetEvent;
667666

668667
UR_CALL(EventCreate(Context, nullptr, false, false, &SignalEvent, false,
669-
!enableProfiling));
668+
!EnableProfiling));
670669
UR_CALL(EventCreate(Context, nullptr, false, false, &WaitEvent, false,
671-
!enableProfiling));
670+
!EnableProfiling));
672671
UR_CALL(EventCreate(Context, nullptr, false, false, &AllResetEvent, false,
673-
!enableProfiling));
672+
!EnableProfiling));
674673
std::vector<ze_event_handle_t> PrecondEvents = {WaitEvent->ZeEvent,
675674
AllResetEvent->ZeEvent};
676675

677676
ze_command_list_handle_t ZeComputeCommandList = nullptr;
678677
UR_CALL(createMainCommandList(Context, Device, IsInOrder, IsUpdatable, false,
679678
ZeComputeCommandList));
680-
UR_CALL(waitForEvents(ZeComputeCommandList, PrecondEvents));
679+
ZE2UR_CALL(zeCommandListAppendBarrier,
680+
(ZeComputeCommandList, nullptr, PrecondEvents.size(),
681+
PrecondEvents.data()));
681682

682683
ze_command_list_handle_t ZeCommandListResetEvents = nullptr;
683684
UR_CALL(createMainCommandList(Context, Device, false, false, false,
@@ -692,7 +693,9 @@ urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
692693
if (Device->hasMainCopyEngine()) {
693694
UR_CALL(createMainCommandList(Context, Device, false, false, true,
694695
ZeCopyCommandList));
695-
UR_CALL(waitForEvents(ZeCopyCommandList, PrecondEvents));
696+
ZE2UR_CALL(zeCommandListAppendBarrier,
697+
(ZeCopyCommandList, nullptr, PrecondEvents.size(),
698+
PrecondEvents.data()));
696699
}
697700

698701
ze_command_list_handle_t ZeComputeCommandListTranslated = nullptr;
@@ -859,10 +862,8 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
859862
ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE |
860863
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
861864

862-
auto Plt = CommandBuffer->Context->getPlatform();
863-
UR_ASSERT(Plt->ZeMutableCmdListExt.Supported,
864-
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
865-
ZE2UR_CALL(Plt->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
865+
auto Platform = CommandBuffer->Context->getPlatform();
866+
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
866867
(CommandBuffer->ZeComputeCommandListTranslated,
867868
&ZeMutableCommandDesc, &CommandId));
868869
DEBUG_LOG(CommandId);
@@ -926,9 +927,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
926927

927928
std::vector<ze_event_handle_t> ZeEventList;
928929
ze_event_handle_t ZeLaunchEvent = nullptr;
929-
UR_CALL(createSyncPointIfNeeded(
930+
UR_CALL(createSyncPointAndGetZeEvents(
930931
UR_COMMAND_KERNEL_LAUNCH, CommandBuffer, NumSyncPointsInWaitList,
931-
SyncPointWaitList, RetSyncPoint, false, ZeEventList, ZeLaunchEvent));
932+
SyncPointWaitList, false, RetSyncPoint, ZeEventList, ZeLaunchEvent));
932933

933934
logger::debug("calling zeCommandListAppendLaunchKernel()");
934935
ZE2UR_CALL(zeCommandListAppendLaunchKernel,
@@ -1123,9 +1124,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
11231124
} else {
11241125
std::vector<ze_event_handle_t> ZeEventList;
11251126
ze_event_handle_t ZeLaunchEvent = nullptr;
1126-
UR_CALL(createSyncPointIfNeeded(
1127+
UR_CALL(createSyncPointAndGetZeEvents(
11271128
UR_COMMAND_USM_PREFETCH, CommandBuffer, NumSyncPointsInWaitList,
1128-
SyncPointWaitList, RetSyncPoint, true, ZeEventList, ZeLaunchEvent));
1129+
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
11291130

11301131
if (NumSyncPointsInWaitList) {
11311132
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
@@ -1186,9 +1187,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
11861187
} else {
11871188
std::vector<ze_event_handle_t> ZeEventList;
11881189
ze_event_handle_t ZeLaunchEvent = nullptr;
1189-
UR_CALL(createSyncPointIfNeeded(
1190+
UR_CALL(createSyncPointAndGetZeEvents(
11901191
UR_COMMAND_USM_ADVISE, CommandBuffer, NumSyncPointsInWaitList,
1191-
SyncPointWaitList, RetSyncPoint, true, ZeEventList, ZeLaunchEvent));
1192+
SyncPointWaitList, true, RetSyncPoint, ZeEventList, ZeLaunchEvent));
11921193

11931194
if (NumSyncPointsInWaitList) {
11941195
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
@@ -1372,7 +1373,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
13721373
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
13731374
ur_event_handle_t *Event) {
13741375
auto Queue = Legacy(UrQueue);
1375-
std::scoped_lock<ur_shared_mutex> lock(Queue->Mutex);
1376+
std::scoped_lock<ur_shared_mutex> Lock(Queue->Mutex);
13761377

13771378
ze_command_queue_handle_t ZeCommandQueue;
13781379
getZeCommandQueue(Queue, false, ZeCommandQueue);
@@ -1752,11 +1753,9 @@ ur_result_t updateKernelCommand(
17521753
MutableCommandDesc.pNext = NextDesc;
17531754
MutableCommandDesc.flags = 0;
17541755

1755-
auto Plt = CommandBuffer->Context->getPlatform();
1756-
UR_ASSERT(Plt->ZeMutableCmdListExt.Supported,
1757-
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1756+
auto Platform = CommandBuffer->Context->getPlatform();
17581757
ZE2UR_CALL(
1759-
Plt->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
1758+
Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
17601759
(CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc));
17611760

17621761
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)