@@ -406,7 +406,7 @@ static ur_result_t enqueueCommandBufferMemCopyHelper(
406
406
} else {
407
407
// FIXME Why doesn't the event need to be host visible
408
408
std::vector<ze_event_handle_t > ZeEventList;
409
- ur_event_handle_t LaunchEvent;
409
+ ur_event_handle_t LaunchEvent = nullptr ;
410
410
UR_CALL (createSyncPoint (CommandType, CommandBuffer, NumSyncPointsInWaitList,
411
411
SyncPointWaitList, RetSyncPoint, false , ZeEventList,
412
412
LaunchEvent));
@@ -761,7 +761,7 @@ static ur_result_t
761
761
createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
762
762
ur_kernel_handle_t Kernel, uint32_t WorkDim,
763
763
const size_t *LocalWorkSize,
764
- ur_exp_command_buffer_command_handle_t & Command) {
764
+ ur_exp_command_buffer_command_handle_t & Command) {
765
765
766
766
// If command-buffer is updatable then get command id which is going to be
767
767
// used if command is updated in the future. This
@@ -1371,20 +1371,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
1371
1371
return UR_RESULT_SUCCESS;
1372
1372
}
1373
1373
1374
- UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp (
1374
+ static ur_result_t validateCommandDesc (
1375
1375
ur_exp_command_buffer_command_handle_t Command,
1376
1376
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1377
- UR_ASSERT (Command->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1378
- UR_ASSERT (CommandDesc->newWorkDim <= 3 ,
1379
- UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1380
1377
1381
- // Lock command, kernel and command buffer for update.
1382
- std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard (
1383
- Command->Mutex , Command->CommandBuffer ->Mutex , Command->Kernel ->Mutex );
1384
- UR_ASSERT (Command->CommandBuffer ->IsUpdatable ,
1385
- UR_RESULT_ERROR_INVALID_OPERATION);
1386
- UR_ASSERT (Command->CommandBuffer ->IsFinalized ,
1387
- UR_RESULT_ERROR_INVALID_OPERATION);
1378
+ auto CommandBuffer = Command->CommandBuffer ;
1379
+ auto SupportedFeatures =
1380
+ Command->CommandBuffer ->Device ->ZeDeviceMutableCmdListsProperties
1381
+ ->mutableCommandFlags ;
1382
+ logger::debug (" Mutable features supported by device {}" , SupportedFeatures);
1388
1383
1389
1384
uint32_t Dim = CommandDesc->newWorkDim ;
1390
1385
if (Dim != 0 ) {
@@ -1409,25 +1404,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1409
1404
}
1410
1405
}
1411
1406
1412
- auto CommandBuffer = Command->CommandBuffer ;
1413
- const void *NextDesc = nullptr ;
1414
- auto SupportedFeatures =
1415
- Command->CommandBuffer ->Device ->ZeDeviceMutableCmdListsProperties
1416
- ->mutableCommandFlags ;
1417
- logger::debug (" Mutable features supported by device {}" , SupportedFeatures);
1418
-
1419
- // We need the created descriptors to live till the point when
1420
- // zexCommandListUpdateMutableCommandsExp is called at the end of the
1421
- // function.
1422
- std::vector<std::unique_ptr<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>>
1423
- ArgDescs;
1424
- std::vector<std::unique_ptr<ZeStruct<ze_mutable_global_offset_exp_desc_t >>>
1425
- OffsetDescs;
1426
- std::vector<std::unique_ptr<ZeStruct<ze_mutable_group_size_exp_desc_t >>>
1427
- GroupSizeDescs;
1428
- std::vector<std::unique_ptr<ZeStruct<ze_mutable_group_count_exp_desc_t >>>
1429
- GroupCountDescs;
1430
-
1431
1407
// Check if new global offset is provided.
1432
1408
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
1433
1409
UR_ASSERT (!NewGlobalWorkOffset ||
@@ -1439,6 +1415,56 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1439
1415
logger::error (" No global offset extension found on this driver" );
1440
1416
return UR_RESULT_ERROR_INVALID_VALUE;
1441
1417
}
1418
+ }
1419
+
1420
+ // Check if new group size is provided.
1421
+ size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1422
+ UR_ASSERT (!NewLocalWorkSize ||
1423
+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1424
+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1425
+
1426
+ // Check if new global size is provided and we need to update group count.
1427
+ size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1428
+ UR_ASSERT (!NewGlobalWorkSize ||
1429
+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT),
1430
+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1431
+ UR_ASSERT (!(NewGlobalWorkSize && !NewLocalWorkSize) ||
1432
+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1433
+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1434
+
1435
+ UR_ASSERT (
1436
+ (!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs &&
1437
+ !CommandDesc->numNewValueArgs ) ||
1438
+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS),
1439
+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1440
+
1441
+ return UR_RESULT_SUCCESS;
1442
+ }
1443
+
1444
+ static ur_result_t updateKernelCommand (
1445
+ ur_exp_command_buffer_command_handle_t Command,
1446
+ const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1447
+
1448
+ // We need the created descriptors to live till the point when
1449
+ // zeCommandListUpdateMutableCommandsExp is called at the end of the
1450
+ // function.
1451
+ std::vector<std::variant<
1452
+ std::unique_ptr<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>,
1453
+ std::unique_ptr<ZeStruct<ze_mutable_global_offset_exp_desc_t >>,
1454
+ std::unique_ptr<ZeStruct<ze_mutable_group_size_exp_desc_t >>,
1455
+ std::unique_ptr<ZeStruct<ze_mutable_group_count_exp_desc_t >>>>
1456
+ Descs;
1457
+
1458
+ const auto CommandBuffer = Command->CommandBuffer ;
1459
+ const void *NextDesc = nullptr ;
1460
+
1461
+ uint32_t Dim = CommandDesc->newWorkDim ;
1462
+ size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
1463
+ size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1464
+ size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1465
+
1466
+ // Check if a new global offset is provided.
1467
+ if (NewGlobalWorkOffset && Dim > 0 ) {
1442
1468
auto MutableGroupOffestDesc =
1443
1469
std::make_unique<ZeStruct<ze_mutable_global_offset_exp_desc_t >>();
1444
1470
MutableGroupOffestDesc->commandId = Command->CommandId ;
@@ -1451,15 +1477,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1451
1477
DEBUG_LOG (MutableGroupOffestDesc->offsetY );
1452
1478
MutableGroupOffestDesc->offsetZ = Dim == 3 ? NewGlobalWorkOffset[2 ] : 0 ;
1453
1479
DEBUG_LOG (MutableGroupOffestDesc->offsetZ );
1480
+
1454
1481
NextDesc = MutableGroupOffestDesc.get ();
1455
- OffsetDescs .push_back (std::move (MutableGroupOffestDesc));
1482
+ Descs .push_back (std::move (MutableGroupOffestDesc));
1456
1483
}
1457
1484
1458
- // Check if new group size is provided.
1459
- size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1460
- UR_ASSERT (!NewLocalWorkSize ||
1461
- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1462
- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1485
+ // Check if a new group size is provided.
1463
1486
if (NewLocalWorkSize && Dim > 0 ) {
1464
1487
auto MutableGroupSizeDesc =
1465
1488
std::make_unique<ZeStruct<ze_mutable_group_size_exp_desc_t >>();
@@ -1473,29 +1496,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1473
1496
DEBUG_LOG (MutableGroupSizeDesc->groupSizeY );
1474
1497
MutableGroupSizeDesc->groupSizeZ = Dim == 3 ? NewLocalWorkSize[2 ] : 1 ;
1475
1498
DEBUG_LOG (MutableGroupSizeDesc->groupSizeZ );
1499
+
1476
1500
NextDesc = MutableGroupSizeDesc.get ();
1477
- GroupSizeDescs .push_back (std::move (MutableGroupSizeDesc));
1501
+ Descs .push_back (std::move (MutableGroupSizeDesc));
1478
1502
}
1479
1503
1480
- // Check if new global size is provided and we need to update group count.
1481
- size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1482
- UR_ASSERT (!NewGlobalWorkSize ||
1483
- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_COUNT),
1484
- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1485
- UR_ASSERT (!(NewGlobalWorkSize && !NewLocalWorkSize) ||
1486
- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_GROUP_SIZE),
1487
- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1488
-
1504
+ // Check if a new global size is provided and if we need to update the group
1505
+ // count.
1489
1506
ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
1490
1507
if (NewGlobalWorkSize && Dim > 0 ) {
1491
- uint32_t WG[3 ];
1492
- // If new global work size is provided but new local work size is not
1493
- // provided then we still need to update local work size based on size
1494
- // suggested by the driver for the kernel.
1508
+ // If a new global work size is provided but a new local work size is not
1509
+ // then we still need to update local work size based on the size suggested
1510
+ // by the driver for the kernel.
1495
1511
bool UpdateWGSize = NewLocalWorkSize == nullptr ;
1512
+
1513
+ uint32_t WG[3 ];
1496
1514
UR_CALL (calculateKernelWorkDimensions (
1497
1515
Command->Kernel , CommandBuffer->Device , ZeThreadGroupDimensions, WG,
1498
1516
Dim, NewGlobalWorkSize, NewLocalWorkSize));
1517
+
1499
1518
auto MutableGroupCountDesc =
1500
1519
std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t >>();
1501
1520
MutableGroupCountDesc->commandId = Command->CommandId ;
@@ -1506,8 +1525,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1506
1525
DEBUG_LOG (MutableGroupCountDesc->pGroupCount ->groupCountX );
1507
1526
DEBUG_LOG (MutableGroupCountDesc->pGroupCount ->groupCountY );
1508
1527
DEBUG_LOG (MutableGroupCountDesc->pGroupCount ->groupCountZ );
1528
+
1509
1529
NextDesc = MutableGroupCountDesc.get ();
1510
- GroupCountDescs .push_back (std::move (MutableGroupCountDesc));
1530
+ Descs .push_back (std::move (MutableGroupCountDesc));
1511
1531
1512
1532
if (UpdateWGSize) {
1513
1533
auto MutableGroupSizeDesc =
@@ -1524,16 +1544,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1524
1544
DEBUG_LOG (MutableGroupSizeDesc->groupSizeZ );
1525
1545
1526
1546
NextDesc = MutableGroupSizeDesc.get ();
1527
- GroupSizeDescs .push_back (std::move (MutableGroupSizeDesc));
1547
+ Descs .push_back (std::move (MutableGroupSizeDesc));
1528
1548
}
1529
1549
}
1530
1550
1531
- UR_ASSERT (
1532
- (!CommandDesc->numNewMemObjArgs && !CommandDesc->numNewPointerArgs &&
1533
- !CommandDesc->numNewValueArgs ) ||
1534
- (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_ARGUMENTS),
1535
- UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1536
-
1537
1551
// Check if new memory object arguments are provided.
1538
1552
for (uint32_t NewMemObjArgNum = CommandDesc->numNewMemObjArgs ;
1539
1553
NewMemObjArgNum-- > 0 ;) {
@@ -1557,6 +1571,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1557
1571
return UR_RESULT_ERROR_INVALID_ARGUMENT;
1558
1572
}
1559
1573
}
1574
+
1560
1575
ur_mem_handle_t NewMemObjArg = NewMemObjArgDesc.hNewMemObjArg ;
1561
1576
// The NewMemObjArg may be a NULL pointer in which case a NULL value is used
1562
1577
// for the kernel argument declared as a pointer to global or constant
@@ -1566,6 +1581,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1566
1581
UR_CALL (NewMemObjArg->getZeHandlePtr (ZeHandlePtr, UrAccessMode,
1567
1582
CommandBuffer->Device ));
1568
1583
}
1584
+
1569
1585
auto ZeMutableArgDesc =
1570
1586
std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>();
1571
1587
ZeMutableArgDesc->commandId = Command->CommandId ;
@@ -1580,14 +1596,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1580
1596
DEBUG_LOG (ZeMutableArgDesc->pArgValue );
1581
1597
1582
1598
NextDesc = ZeMutableArgDesc.get ();
1583
- ArgDescs .push_back (std::move (ZeMutableArgDesc));
1599
+ Descs .push_back (std::move (ZeMutableArgDesc));
1584
1600
}
1585
1601
1586
1602
// Check if there are new pointer arguments.
1587
1603
for (uint32_t NewPointerArgNum = CommandDesc->numNewPointerArgs ;
1588
1604
NewPointerArgNum-- > 0 ;) {
1589
1605
ur_exp_command_buffer_update_pointer_arg_desc_t NewPointerArgDesc =
1590
1606
CommandDesc->pNewPointerArgList [NewPointerArgNum];
1607
+
1591
1608
auto ZeMutableArgDesc =
1592
1609
std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>();
1593
1610
ZeMutableArgDesc->commandId = Command->CommandId ;
@@ -1602,14 +1619,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1602
1619
DEBUG_LOG (ZeMutableArgDesc->pArgValue );
1603
1620
1604
1621
NextDesc = ZeMutableArgDesc.get ();
1605
- ArgDescs .push_back (std::move (ZeMutableArgDesc));
1622
+ Descs .push_back (std::move (ZeMutableArgDesc));
1606
1623
}
1607
1624
1608
1625
// Check if there are new value arguments.
1609
1626
for (uint32_t NewValueArgNum = CommandDesc->numNewValueArgs ;
1610
1627
NewValueArgNum-- > 0 ;) {
1611
1628
ur_exp_command_buffer_update_value_arg_desc_t NewValueArgDesc =
1612
1629
CommandDesc->pNewValueArgList [NewValueArgNum];
1630
+
1613
1631
auto ZeMutableArgDesc =
1614
1632
std::make_unique<ZeStruct<ze_mutable_kernel_argument_exp_desc_t >>();
1615
1633
ZeMutableArgDesc->commandId = Command->CommandId ;
@@ -1634,26 +1652,52 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1634
1652
}
1635
1653
ZeMutableArgDesc->pArgValue = ArgValuePtr;
1636
1654
DEBUG_LOG (ZeMutableArgDesc->pArgValue );
1655
+
1637
1656
NextDesc = ZeMutableArgDesc.get ();
1638
- ArgDescs .push_back (std::move (ZeMutableArgDesc));
1657
+ Descs .push_back (std::move (ZeMutableArgDesc));
1639
1658
}
1640
1659
1641
1660
ZeStruct<ze_mutable_commands_exp_desc_t > MutableCommandDesc;
1642
1661
MutableCommandDesc.pNext = NextDesc;
1643
1662
MutableCommandDesc.flags = 0 ;
1644
1663
1645
- // We must synchronize mutable command list execution before mutating.
1646
- if (ze_fence_handle_t &ZeFence = CommandBuffer->ZeActiveFence ) {
1647
- ZE2UR_CALL (zeFenceHostSynchronize, (ZeFence, UINT64_MAX));
1648
- }
1649
-
1650
1664
auto Plt = CommandBuffer->Context ->getPlatform ();
1651
1665
UR_ASSERT (Plt->ZeMutableCmdListExt .Supported ,
1652
1666
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1653
1667
ZE2UR_CALL (
1654
1668
Plt->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandsExp ,
1655
1669
(CommandBuffer->ZeComputeCommandListTranslated , &MutableCommandDesc));
1656
- ZE2UR_CALL (zeCommandListClose, (CommandBuffer->ZeComputeCommandList ));
1670
+
1671
+ return UR_RESULT_SUCCESS;
1672
+ }
1673
+
1674
+ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp (
1675
+ ur_exp_command_buffer_command_handle_t Command,
1676
+ const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1677
+ UR_ASSERT (Command->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
1678
+ UR_ASSERT (CommandDesc->newWorkDim <= 3 ,
1679
+ UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
1680
+
1681
+ // Lock command, kernel and command buffer for update.
1682
+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard (
1683
+ Command->Mutex , Command->CommandBuffer ->Mutex , Command->Kernel ->Mutex );
1684
+
1685
+ UR_ASSERT (Command->CommandBuffer ->IsUpdatable ,
1686
+ UR_RESULT_ERROR_INVALID_OPERATION);
1687
+ UR_ASSERT (Command->CommandBuffer ->IsFinalized ,
1688
+ UR_RESULT_ERROR_INVALID_OPERATION);
1689
+
1690
+ UR_CALL (validateCommandDesc (Command, CommandDesc));
1691
+
1692
+ // We must synchronize mutable command list execution before mutating.
1693
+ if (ze_fence_handle_t &ZeFence = Command->CommandBuffer ->ZeActiveFence ) {
1694
+ ZE2UR_CALL (zeFenceHostSynchronize, (ZeFence, UINT64_MAX));
1695
+ }
1696
+
1697
+ UR_CALL (updateKernelCommand (Command, CommandDesc));
1698
+
1699
+ ZE2UR_CALL (zeCommandListClose,
1700
+ (Command->CommandBuffer ->ZeComputeCommandList ));
1657
1701
1658
1702
return UR_RESULT_SUCCESS;
1659
1703
}
0 commit comments