@@ -732,14 +732,13 @@ pi_result _pi_device::initialize(int SubSubDeviceOrdinal,
732
732
}
733
733
}
734
734
735
- // Reinitialize a sub-sub-device with its own ordinal, index and numQueues
735
+ // Reinitialize a sub-sub-device with its own ordinal and index.
736
736
// Our sub-sub-device representation is currently [Level-Zero sub-device
737
- // handle + Level-Zero compute group/engine index]. As we have a single queue
738
- // per device, we need to reinitialize numQueues in ZeProperties to be 1 .
737
+ // handle + Level-Zero compute group/engine index]. Only the queue at the
738
+ // specified index will be used to submit work to the sub-sub-device.
739
739
if (SubSubDeviceOrdinal >= 0 ) {
740
740
QueueGroup[queue_group_info_t ::Compute].ZeOrdinal = SubSubDeviceOrdinal;
741
741
QueueGroup[queue_group_info_t ::Compute].ZeIndex = SubSubDeviceIndex;
742
- QueueGroup[queue_group_info_t ::Compute].ZeProperties .numQueues = 1 ;
743
742
} else { // Proceed with initialization for root and sub-device
744
743
// How is it possible that there are no "compute" capabilities?
745
744
if (QueueGroup[queue_group_info_t ::Compute].ZeOrdinal < 0 ) {
@@ -862,6 +861,50 @@ pi_device _pi_context::getRootDevice() const {
862
861
}
863
862
864
863
pi_result _pi_context::initialize () {
864
+
865
+ // Helper lambda to create various USM allocators for a device.
866
+ auto createUSMAllocators = [this ](pi_device Device) {
867
+ SharedMemAllocContexts.emplace (
868
+ std::piecewise_construct, std::make_tuple (Device),
869
+ std::make_tuple (std::unique_ptr<SystemMemory>(
870
+ new USMSharedMemoryAlloc (this , Device))));
871
+ SharedReadOnlyMemAllocContexts.emplace (
872
+ std::piecewise_construct, std::make_tuple (Device),
873
+ std::make_tuple (std::unique_ptr<SystemMemory>(
874
+ new USMSharedReadOnlyMemoryAlloc (this , Device))));
875
+ DeviceMemAllocContexts.emplace (
876
+ std::piecewise_construct, std::make_tuple (Device),
877
+ std::make_tuple (std::unique_ptr<SystemMemory>(
878
+ new USMDeviceMemoryAlloc (this , Device))));
879
+ };
880
+
881
+ // Recursive helper to call createUSMAllocators for all sub-devices
882
+ std::function<void (pi_device)> createUSMAllocatorsRecursive;
883
+ createUSMAllocatorsRecursive =
884
+ [this , createUSMAllocators,
885
+ &createUSMAllocatorsRecursive](pi_device Device) -> void {
886
+ createUSMAllocators (Device);
887
+ for (auto &SubDevice : Device->SubDevices )
888
+ createUSMAllocatorsRecursive (SubDevice);
889
+ };
890
+
891
+ // Create USM allocator context for each pair (device, context).
892
+ //
893
+ for (auto &Device : Devices) {
894
+ createUSMAllocatorsRecursive (Device);
895
+ }
896
+ // Create USM allocator context for host. Device and Shared USM allocations
897
+ // are device-specific. Host allocations are not device-dependent therefore
898
+ // we don't need a map with device as key.
899
+ HostMemAllocContext = std::make_unique<USMAllocContext>(
900
+ std::unique_ptr<SystemMemory>(new USMHostMemoryAlloc (this )));
901
+
902
+ // We may allocate memory on this root device, so create its allocators.
903
+ if (SingleRootDevice && DeviceMemAllocContexts.find (SingleRootDevice) ==
904
+ DeviceMemAllocContexts.end ()) {
905
+ createUSMAllocators (SingleRootDevice);
906
+ }
907
+
865
908
// Create the immediate command list to be used for initializations
866
909
// Created as synchronous so level-zero performs implicit synchronization and
867
910
// there is no need to query for completion in the plugin
@@ -1112,32 +1155,30 @@ _pi_queue::_pi_queue(std::vector<ze_command_queue_handle_t> &ComputeQueues,
1112
1155
// First, see if the queue's device allows for round-robin or it is
1113
1156
// fixed to one particular compute CCS (it is so for sub-sub-devices).
1114
1157
auto &ComputeQueueGroupInfo = Device->QueueGroup [queue_type::Compute];
1158
+ ComputeQueueGroup.ZeQueues = ComputeQueues;
1115
1159
if (ComputeQueueGroupInfo.ZeIndex >= 0 ) {
1116
1160
ComputeQueueGroup.LowerIndex = ComputeQueueGroupInfo.ZeIndex ;
1117
1161
ComputeQueueGroup.UpperIndex = ComputeQueueGroupInfo.ZeIndex ;
1118
1162
ComputeQueueGroup.NextIndex = ComputeQueueGroupInfo.ZeIndex ;
1119
1163
} else {
1120
- ComputeQueueGroup.LowerIndex = 0 ;
1121
- ComputeQueueGroup.UpperIndex = INT_MAX;
1122
- ComputeQueueGroup.NextIndex = 0 ;
1123
- }
1124
-
1125
- uint32_t FilterLowerIndex = getRangeOfAllowedComputeEngines.first ;
1126
- uint32_t FilterUpperIndex = getRangeOfAllowedComputeEngines.second ;
1127
- FilterUpperIndex =
1128
- std::min ((size_t )FilterUpperIndex, ComputeQueues.size () - 1 );
1129
- if (FilterLowerIndex <= FilterUpperIndex) {
1130
- ComputeQueueGroup.ZeQueues = ComputeQueues;
1131
- ComputeQueueGroup.LowerIndex = FilterLowerIndex;
1132
- ComputeQueueGroup.UpperIndex = FilterUpperIndex;
1133
- ComputeQueueGroup.NextIndex = ComputeQueueGroup.LowerIndex ;
1134
- // Create space to hold immediate commandlists corresponding to the ZeQueues
1135
- if (UseImmediateCommandLists) {
1136
- ComputeQueueGroup.ImmCmdLists = std::vector<pi_command_list_ptr_t >(
1137
- ComputeQueueGroup.ZeQueues .size (), CommandListMap.end ());
1164
+ // Set-up to round-robin across allowed range of engines.
1165
+ uint32_t FilterLowerIndex = getRangeOfAllowedComputeEngines.first ;
1166
+ uint32_t FilterUpperIndex = getRangeOfAllowedComputeEngines.second ;
1167
+ FilterUpperIndex = std::min ((size_t )FilterUpperIndex,
1168
+ FilterLowerIndex + ComputeQueues.size () - 1 );
1169
+ if (FilterLowerIndex <= FilterUpperIndex) {
1170
+ ComputeQueueGroup.LowerIndex = FilterLowerIndex;
1171
+ ComputeQueueGroup.UpperIndex = FilterUpperIndex;
1172
+ ComputeQueueGroup.NextIndex = ComputeQueueGroup.LowerIndex ;
1173
+ // Create space to hold immediate commandlists corresponding to the
1174
+ // ZeQueues
1175
+ if (UseImmediateCommandLists) {
1176
+ ComputeQueueGroup.ImmCmdLists = std::vector<pi_command_list_ptr_t >(
1177
+ ComputeQueueGroup.ZeQueues .size (), CommandListMap.end ());
1178
+ }
1179
+ } else {
1180
+ die (" No compute queue available/allowed." );
1138
1181
}
1139
- } else {
1140
- die (" No compute queue available." );
1141
1182
}
1142
1183
1143
1184
// Copy group initialization.
@@ -1148,8 +1189,8 @@ _pi_queue::_pi_queue(std::vector<ze_command_queue_handle_t> &ComputeQueues,
1148
1189
} else {
1149
1190
uint32_t FilterLowerIndex = getRangeOfAllowedCopyEngines.first ;
1150
1191
uint32_t FilterUpperIndex = getRangeOfAllowedCopyEngines.second ;
1151
- FilterUpperIndex =
1152
- std::min (( size_t )FilterUpperIndex, CopyQueues.size () - 1 );
1192
+ FilterUpperIndex = std::min (( size_t )FilterUpperIndex,
1193
+ FilterLowerIndex + CopyQueues.size () - 1 );
1153
1194
if (FilterLowerIndex <= FilterUpperIndex) {
1154
1195
CopyQueueGroup.ZeQueues = CopyQueues;
1155
1196
CopyQueueGroup.LowerIndex = FilterLowerIndex;
@@ -3410,11 +3451,7 @@ pi_result piQueueCreate(pi_context Context, pi_device Device,
3410
3451
PI_ASSERT (Context, PI_ERROR_INVALID_CONTEXT);
3411
3452
PI_ASSERT (Queue, PI_ERROR_INVALID_QUEUE);
3412
3453
PI_ASSERT (Device, PI_ERROR_INVALID_DEVICE);
3413
-
3414
- if (std::find (Context->Devices .begin (), Context->Devices .end (), Device) ==
3415
- Context->Devices .end ()) {
3416
- return PI_ERROR_INVALID_DEVICE;
3417
- }
3454
+ PI_ASSERT (Context->isValidDevice (Device), PI_ERROR_INVALID_DEVICE);
3418
3455
3419
3456
// Create placeholder queues in the compute queue group.
3420
3457
// Actual L0 queues will be created at first use.
@@ -4196,11 +4233,7 @@ pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
4196
4233
pi_device Device = nullptr ;
4197
4234
if (ZeDevice) {
4198
4235
Device = Context->getPlatform ()->getDeviceFromNativeHandle (ZeDevice);
4199
- // Check that the device is present in this context.
4200
- if (std::find (Context->Devices .begin (), Context->Devices .end (), Device) ==
4201
- Context->Devices .end ()) {
4202
- return PI_ERROR_INVALID_CONTEXT;
4203
- }
4236
+ PI_ASSERT (Context->isValidDevice (Device), PI_ERROR_INVALID_CONTEXT);
4204
4237
}
4205
4238
4206
4239
try {
@@ -4469,12 +4502,7 @@ pi_result piProgramLink(pi_context Context, pi_uint32 NumDevices,
4469
4502
4470
4503
// Validate input parameters.
4471
4504
PI_ASSERT (DeviceList, PI_ERROR_INVALID_DEVICE);
4472
- {
4473
- auto DeviceEntry =
4474
- find (Context->Devices .begin (), Context->Devices .end (), DeviceList[0 ]);
4475
- if (DeviceEntry == Context->Devices .end ())
4476
- return PI_ERROR_INVALID_DEVICE;
4477
- }
4505
+ PI_ASSERT (Context->isValidDevice (DeviceList[0 ]), PI_ERROR_INVALID_DEVICE);
4478
4506
PI_ASSERT (!PFnNotify && !UserData, PI_ERROR_INVALID_VALUE);
4479
4507
if (NumInputPrograms == 0 || InputPrograms == nullptr )
4480
4508
return PI_ERROR_INVALID_VALUE;
@@ -4679,12 +4707,9 @@ pi_result piProgramBuild(pi_program Program, pi_uint32 NumDevices,
4679
4707
std::scoped_lock Guard (Program->Mutex );
4680
4708
// Check if device belongs to associated context.
4681
4709
PI_ASSERT (Program->Context , PI_ERROR_INVALID_PROGRAM);
4682
- {
4683
- auto DeviceEntry = find (Program->Context ->Devices .begin (),
4684
- Program->Context ->Devices .end (), DeviceList[0 ]);
4685
- if (DeviceEntry == Program->Context ->Devices .end ())
4686
- return PI_ERROR_INVALID_VALUE;
4687
- }
4710
+ PI_ASSERT (Program->Context ->isValidDevice (DeviceList[0 ]),
4711
+ PI_ERROR_INVALID_VALUE);
4712
+
4688
4713
// It is legal to build a program created from either IL or from native
4689
4714
// device code.
4690
4715
if (Program->State != _pi_program::IL &&
0 commit comments