Skip to content

Commit da92770

Browse files
committed
Kernel Update
1 parent 802a3ce commit da92770

File tree

1 file changed

+113
-69
lines changed

1 file changed

+113
-69
lines changed

source/adapters/level_zero/command_buffer.cpp

Lines changed: 113 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ static ur_result_t enqueueCommandBufferMemCopyHelper(
406406
} else {
407407
// FIXME: Why doesn't the event need to be host visible?
408408
std::vector<ze_event_handle_t> ZeEventList;
409-
ur_event_handle_t LaunchEvent;
409+
ur_event_handle_t LaunchEvent = nullptr;
410410
UR_CALL(createSyncPoint(CommandType, CommandBuffer, NumSyncPointsInWaitList,
411411
SyncPointWaitList, RetSyncPoint, false, ZeEventList,
412412
LaunchEvent));
@@ -761,7 +761,7 @@ static ur_result_t
761761
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
762762
ur_kernel_handle_t Kernel, uint32_t WorkDim,
763763
const size_t *LocalWorkSize,
764-
ur_exp_command_buffer_command_handle_t& Command) {
764+
ur_exp_command_buffer_command_handle_t &Command) {
765765

766766
// If command-buffer is updatable then get command id which is going to be
767767
// used if command is updated in the future. This
@@ -1371,20 +1371,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
13711371
return UR_RESULT_SUCCESS;
13721372
}
13731373

1374-
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1374+
static ur_result_t validateCommandDesc(
13751375
ur_exp_command_buffer_command_handle_t Command,
13761376
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1377-
UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1378-
UR_ASSERT(CommandDesc->newWorkDim <= 3,
1379-
UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
13801377

1381-
// Lock command, kernel and command buffer for update.
1382-
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard(
1383-
Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex);
1384-
UR_ASSERT(Command->CommandBuffer->IsUpdatable,
1385-
UR_RESULT_ERROR_INVALID_OPERATION);
1386-
UR_ASSERT(Command->CommandBuffer->IsFinalized,
1387-
UR_RESULT_ERROR_INVALID_OPERATION);
1378+
auto CommandBuffer = Command->CommandBuffer;
1379+
auto SupportedFeatures =
1380+
Command->CommandBuffer->Device->ZeDeviceMutableCmdListsProperties
1381+
->mutableCommandFlags;
1382+
logger::debug("Mutable features supported by device {}", SupportedFeatures);
13881383

13891384
uint32_t Dim = CommandDesc->newWorkDim;
13901385
if (Dim != 0) {
@@ -1409,25 +1404,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14091404
}
14101405
}
14111406

1412-
auto CommandBuffer = Command->CommandBuffer;
1413-
const void *NextDesc = nullptr;
1414-
auto SupportedFeatures =
1415-
Command->CommandBuffer->Device->ZeDeviceMutableCmdListsProperties
1416-
->mutableCommandFlags;
1417-
logger::debug("Mutable features supported by device {}", SupportedFeatures);
1418-
1419-
// We need the created descriptors to live till the point when
1420-
// zexCommandListUpdateMutableCommandsExp is called at the end of the
1421-
// function.
1422-
std::vector<std::unique_ptr<ZeStruct<ze_mutable_kernel_argument_exp_desc_t>>>
1423-
ArgDescs;
1424-
std::vector<std::unique_ptr<ZeStruct<ze_mutable_global_offset_exp_desc_t>>>
1425-
OffsetDescs;
1426-
std::vector<std::unique_ptr<ZeStruct<ze_mutable_group_size_exp_desc_t>>>
1427-
GroupSizeDescs;
1428-
std::vector<std::unique_ptr<ZeStruct<ze_mutable_group_count_exp_desc_t>>>
1429-
GroupCountDescs;
1430-
14311407
// Check if new global offset is provided.
14321408
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
14331409
UR_ASSERT(!NewGlobalWorkOffset ||
@@ -1439,6 +1415,56 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14391415
logger::error("No global offset extension found on this driver");
14401416
return UR_RESULT_ERROR_INVALID_VALUE;
14411417
}
1418+
}
1419+
1420+
// Check if new group size is provided.
1421+
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize;
1422+
UR_ASSERT(!NewLocalWorkSize ||
1423+
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1424+
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1425+
1426+
// Check if new global size is provided and we need to update group count.
1427+
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize;
1428+
UR_ASSERT(!NewGlobalWorkSize ||
1429+
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT),
1430+
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1431+
UR_ASSERT(!(NewGlobalWorkSize && !NewLocalWorkSize) ||
1432+
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1433+
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1434+
1435+
UR_ASSERT(
1436+
(!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs &&
1437+
!CommandDesc->numNewValueArgs) ||
1438+
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS),
1439+
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1440+
1441+
return UR_RESULT_SUCCESS;
1442+
}
1443+
1444+
static ur_result_t updateKernelCommand(
1445+
ur_exp_command_buffer_command_handle_t Command,
1446+
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1447+
1448+
// We need the created descriptors to live till the point when
1449+
// zexCommandListUpdateMutableCommandsExp is called at the end of the
1450+
// function.
1451+
std::vector<std::variant<
1452+
std::unique_ptr<ZeStruct<ze_mutable_kernel_argument_exp_desc_t>>,
1453+
std::unique_ptr<ZeStruct<ze_mutable_global_offset_exp_desc_t>>,
1454+
std::unique_ptr<ZeStruct<ze_mutable_group_size_exp_desc_t>>,
1455+
std::unique_ptr<ZeStruct<ze_mutable_group_count_exp_desc_t>>>>
1456+
Descs;
1457+
1458+
const auto CommandBuffer = Command->CommandBuffer;
1459+
const void *NextDesc = nullptr;
1460+
1461+
uint32_t Dim = CommandDesc->newWorkDim;
1462+
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
1463+
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize;
1464+
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize;
1465+
1466+
// Check if a new global offset is provided.
1467+
if (NewGlobalWorkOffset && Dim > 0) {
14421468
auto MutableGroupOffestDesc =
14431469
std::make_unique<ZeStruct<ze_mutable_global_offset_exp_desc_t>>();
14441470
MutableGroupOffestDesc->commandId = Command->CommandId;
@@ -1451,15 +1477,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14511477
DEBUG_LOG(MutableGroupOffestDesc->offsetY);
14521478
MutableGroupOffestDesc->offsetZ = Dim == 3 ? NewGlobalWorkOffset[2] : 0;
14531479
DEBUG_LOG(MutableGroupOffestDesc->offsetZ);
1480+
14541481
NextDesc = MutableGroupOffestDesc.get();
1455-
OffsetDescs.push_back(std::move(MutableGroupOffestDesc));
1482+
Descs.push_back(std::move(MutableGroupOffestDesc));
14561483
}
14571484

1458-
// Check if new group size is provided.
1459-
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize;
1460-
UR_ASSERT(!NewLocalWorkSize ||
1461-
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1462-
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1485+
// Check if a new group size is provided.
14631486
if (NewLocalWorkSize && Dim > 0) {
14641487
auto MutableGroupSizeDesc =
14651488
std::make_unique<ZeStruct<ze_mutable_group_size_exp_desc_t>>();
@@ -1473,29 +1496,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
14731496
DEBUG_LOG(MutableGroupSizeDesc->groupSizeY);
14741497
MutableGroupSizeDesc->groupSizeZ = Dim == 3 ? NewLocalWorkSize[2] : 1;
14751498
DEBUG_LOG(MutableGroupSizeDesc->groupSizeZ);
1499+
14761500
NextDesc = MutableGroupSizeDesc.get();
1477-
GroupSizeDescs.push_back(std::move(MutableGroupSizeDesc));
1501+
Descs.push_back(std::move(MutableGroupSizeDesc));
14781502
}
14791503

1480-
// Check if new global size is provided and we need to update group count.
1481-
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize;
1482-
UR_ASSERT(!NewGlobalWorkSize ||
1483-
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT),
1484-
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1485-
UR_ASSERT(!(NewGlobalWorkSize && !NewLocalWorkSize) ||
1486-
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1487-
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1488-
1504+
// Check if a new global size is provided and if we need to update the group
1505+
// count.
14891506
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
14901507
if (NewGlobalWorkSize && Dim > 0) {
1491-
uint32_t WG[3];
1492-
// If new global work size is provided but new local work size is not
1493-
// provided then we still need to update local work size based on size
1494-
// suggested by the driver for the kernel.
1508+
// If a new global work size is provided but a new local work size is not
1509+
// then we still need to update local work size based on the size suggested
1510+
// by the driver for the kernel.
14951511
bool UpdateWGSize = NewLocalWorkSize == nullptr;
1512+
1513+
uint32_t WG[3];
14961514
UR_CALL(calculateKernelWorkDimensions(
14971515
Command->Kernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG,
14981516
Dim, NewGlobalWorkSize, NewLocalWorkSize));
1517+
14991518
auto MutableGroupCountDesc =
15001519
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();
15011520
MutableGroupCountDesc->commandId = Command->CommandId;
@@ -1506,8 +1525,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15061525
DEBUG_LOG(MutableGroupCountDesc->pGroupCount->groupCountX);
15071526
DEBUG_LOG(MutableGroupCountDesc->pGroupCount->groupCountY);
15081527
DEBUG_LOG(MutableGroupCountDesc->pGroupCount->groupCountZ);
1528+
15091529
NextDesc = MutableGroupCountDesc.get();
1510-
GroupCountDescs.push_back(std::move(MutableGroupCountDesc));
1530+
Descs.push_back(std::move(MutableGroupCountDesc));
15111531

15121532
if (UpdateWGSize) {
15131533
auto MutableGroupSizeDesc =
@@ -1524,16 +1544,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15241544
DEBUG_LOG(MutableGroupSizeDesc->groupSizeZ);
15251545

15261546
NextDesc = MutableGroupSizeDesc.get();
1527-
GroupSizeDescs.push_back(std::move(MutableGroupSizeDesc));
1547+
Descs.push_back(std::move(MutableGroupSizeDesc));
15281548
}
15291549
}
15301550

1531-
UR_ASSERT(
1532-
(!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs &&
1533-
!CommandDesc->numNewValueArgs) ||
1534-
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS),
1535-
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1536-
15371551
// Check if new memory object arguments are provided.
15381552
for (uint32_t NewMemObjArgNum = CommandDesc->numNewMemObjArgs;
15391553
NewMemObjArgNum-- > 0;) {
@@ -1557,6 +1571,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15571571
return UR_RESULT_ERROR_INVALID_ARGUMENT;
15581572
}
15591573
}
1574+
15601575
ur_mem_handle_t NewMemObjArg = NewMemObjArgDesc.hNewMemObjArg;
15611576
// The NewMemObjArg may be a NULL pointer in which case a NULL value is used
15621577
// for the kernel argument declared as a pointer to global or constant
@@ -1566,6 +1581,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15661581
UR_CALL(NewMemObjArg->getZeHandlePtr(ZeHandlePtr, UrAccessMode,
15671582
CommandBuffer->Device));
15681583
}
1584+
15691585
auto ZeMutableArgDesc =
15701586
std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t>>();
15711587
ZeMutableArgDesc->commandId = Command->CommandId;
@@ -1580,14 +1596,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
15801596
DEBUG_LOG(ZeMutableArgDesc->pArgValue);
15811597

15821598
NextDesc = ZeMutableArgDesc.get();
1583-
ArgDescs.push_back(std::move(ZeMutableArgDesc));
1599+
Descs.push_back(std::move(ZeMutableArgDesc));
15841600
}
15851601

15861602
// Check if there are new pointer arguments.
15871603
for (uint32_t NewPointerArgNum = CommandDesc->numNewPointerArgs;
15881604
NewPointerArgNum-- > 0;) {
15891605
ur_exp_command_buffer_update_pointer_arg_desc_t NewPointerArgDesc =
15901606
CommandDesc->pNewPointerArgList[NewPointerArgNum];
1607+
15911608
auto ZeMutableArgDesc =
15921609
std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t>>();
15931610
ZeMutableArgDesc->commandId = Command->CommandId;
@@ -1602,14 +1619,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
16021619
DEBUG_LOG(ZeMutableArgDesc->pArgValue);
16031620

16041621
NextDesc = ZeMutableArgDesc.get();
1605-
ArgDescs.push_back(std::move(ZeMutableArgDesc));
1622+
Descs.push_back(std::move(ZeMutableArgDesc));
16061623
}
16071624

16081625
// Check if there are new value arguments.
16091626
for (uint32_t NewValueArgNum = CommandDesc->numNewValueArgs;
16101627
NewValueArgNum-- > 0;) {
16111628
ur_exp_command_buffer_update_value_arg_desc_t NewValueArgDesc =
16121629
CommandDesc->pNewValueArgList[NewValueArgNum];
1630+
16131631
auto ZeMutableArgDesc =
16141632
std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t>>();
16151633
ZeMutableArgDesc->commandId = Command->CommandId;
@@ -1634,26 +1652,52 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
16341652
}
16351653
ZeMutableArgDesc->pArgValue = ArgValuePtr;
16361654
DEBUG_LOG(ZeMutableArgDesc->pArgValue);
1655+
16371656
NextDesc = ZeMutableArgDesc.get();
1638-
ArgDescs.push_back(std::move(ZeMutableArgDesc));
1657+
Descs.push_back(std::move(ZeMutableArgDesc));
16391658
}
16401659

16411660
ZeStruct<ze_mutable_commands_exp_desc_t> MutableCommandDesc;
16421661
MutableCommandDesc.pNext = NextDesc;
16431662
MutableCommandDesc.flags = 0;
16441663

1645-
// We must synchronize mutable command list execution before mutating.
1646-
if (ze_fence_handle_t &ZeFence = CommandBuffer->ZeActiveFence) {
1647-
ZE2UR_CALL(zeFenceHostSynchronize, (ZeFence, UINT64_MAX));
1648-
}
1649-
16501664
auto Plt = CommandBuffer->Context->getPlatform();
16511665
UR_ASSERT(Plt->ZeMutableCmdListExt.Supported,
16521666
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
16531667
ZE2UR_CALL(
16541668
Plt->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
16551669
(CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc));
1656-
ZE2UR_CALL(zeCommandListClose, (CommandBuffer->ZeComputeCommandList));
1670+
1671+
return UR_RESULT_SUCCESS;
1672+
}
1673+
1674+
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1675+
ur_exp_command_buffer_command_handle_t Command,
1676+
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1677+
UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1678+
UR_ASSERT(CommandDesc->newWorkDim <= 3,
1679+
UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1680+
1681+
// Lock command, kernel and command buffer for update.
1682+
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard(
1683+
Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex);
1684+
1685+
UR_ASSERT(Command->CommandBuffer->IsUpdatable,
1686+
UR_RESULT_ERROR_INVALID_OPERATION);
1687+
UR_ASSERT(Command->CommandBuffer->IsFinalized,
1688+
UR_RESULT_ERROR_INVALID_OPERATION);
1689+
1690+
UR_CALL(validateCommandDesc(Command, CommandDesc));
1691+
1692+
// We must synchronize mutable command list execution before mutating.
1693+
if (ze_fence_handle_t &ZeFence = Command->CommandBuffer->ZeActiveFence) {
1694+
ZE2UR_CALL(zeFenceHostSynchronize, (ZeFence, UINT64_MAX));
1695+
}
1696+
1697+
UR_CALL(updateKernelCommand(Command, CommandDesc));
1698+
1699+
ZE2UR_CALL(zeCommandListClose,
1700+
(Command->CommandBuffer->ZeComputeCommandList));
16571701

16581702
return UR_RESULT_SUCCESS;
16591703
}

0 commit comments

Comments
 (0)