@@ -153,8 +153,10 @@ std::vector<platform> platform_impl::get_platforms() {
153
153
// to distinguish the case where we are working with ONEAPI_DEVICE_SELECTOR
154
154
// in the places where the functionality diverges between these two
155
155
// environment variables.
156
+ // The return value is a vector that represents the indices of the chosen
157
+ // devices.
156
158
template <typename ListT, typename FilterT>
157
- static int filterDeviceFilter (std::vector<RT::PiDevice> &PiDevices,
159
+ static std::vector< int > filterDeviceFilter (std::vector<RT::PiDevice> &PiDevices,
158
160
RT::PiPlatform Platform, ListT *FilterList) {
159
161
160
162
constexpr bool is_ods_target = std::is_same_v<FilterT, ods_target>;
@@ -184,22 +186,23 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
184
186
// in the if statement above because it will then be out of scope in the rest
185
187
// of the function
186
188
std::map<RT::PiDevice *, bool > Blacklist;
189
+ std::vector<int > original_indices;
187
190
188
191
std::vector<plugin> &Plugins = RT::initialize ();
189
192
auto It =
190
193
std::find_if (Plugins.begin (), Plugins.end (), [Platform](plugin &Plugin) {
191
194
return Plugin.containsPiPlatform (Platform);
192
195
});
193
- if (It == Plugins.end ())
194
- return -1 ;
196
+ if (It == Plugins.end ()) {
197
+ return original_indices;
198
+ }
195
199
plugin &Plugin = *It;
196
200
backend Backend = Plugin.getBackend ();
197
201
int InsertIDx = 0 ;
198
202
// DeviceIds should be given consecutive numbers across platforms in the same
199
203
// backend
200
204
std::lock_guard<std::mutex> Guard (*Plugin.getPluginMutex ());
201
205
int DeviceNum = Plugin.getStartingDeviceId (Platform);
202
- int StartingNum = DeviceNum;
203
206
for (RT::PiDevice Device : PiDevices) {
204
207
RT::PiDeviceType PiDevType;
205
208
Plugin.call <PiApiKind::piDeviceGetInfo>(Device, PI_DEVICE_INFO_TYPE,
@@ -223,6 +226,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
223
226
if (!Blacklist[&Device]) { // ensure it is not blacklisted
224
227
if (!Filter.IsNegativeTarget ) { // is filter positive?
225
228
PiDevices[InsertIDx++] = Device;
229
+ original_indices.push_back (DeviceNum);
226
230
} else {
227
231
// Filter is negative and the device matches the filter so
228
232
// blacklist the device.
@@ -231,6 +235,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
231
235
}
232
236
} else { // dealing with SYCL_DEVICE_FILTER
233
237
PiDevices[InsertIDx++] = Device;
238
+ original_indices.push_back (DeviceNum);
234
239
}
235
240
break ;
236
241
}
@@ -241,6 +246,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
241
246
if (!Blacklist[&Device]) {
242
247
if (!Filter.IsNegativeTarget ) {
243
248
PiDevices[InsertIDx++] = Device;
249
+ original_indices.push_back (DeviceNum);
244
250
} else {
245
251
// Filter is negative and the device matches the filter so
246
252
// blacklist the device.
@@ -249,6 +255,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
249
255
}
250
256
} else {
251
257
PiDevices[InsertIDx++] = Device;
258
+ original_indices.push_back (DeviceNum);
252
259
}
253
260
break ;
254
261
}
@@ -262,7 +269,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
262
269
// to assign a unique device id number across platforms that belong to
263
270
// the same backend. For example, opencl:cpu:0, opencl:acc:1, opencl:gpu:2
264
271
Plugin.setLastDeviceId (Platform, DeviceNum);
265
- return StartingNum ;
272
+ return original_indices ;
266
273
}
267
274
268
275
std::shared_ptr<device_impl>
@@ -307,7 +314,7 @@ static bool supportsPartitionProperty(const device &dev,
307
314
308
315
static std::vector<device> amendDeviceAndSubDevices (
309
316
backend PlatformBackend, std::vector<device> &DeviceList,
310
- ods_target_list *OdsTargetList, int PlatformDeviceIndex ,
317
+ ods_target_list *OdsTargetList, const std::vector< int >& original_indices ,
311
318
PlatformImplPtr PlatformImpl) {
312
319
constexpr info::partition_property partitionProperty =
313
320
info::partition_property::partition_by_affinity_domain;
@@ -335,7 +342,7 @@ static std::vector<device> amendDeviceAndSubDevices(
335
342
336
343
} else if (target.DeviceNum ) { // opencl:0
337
344
deviceMatch =
338
- (target.DeviceNum .value () == PlatformDeviceIndex + ( int )i );
345
+ (target.DeviceNum .value () == original_indices[i] );
339
346
}
340
347
341
348
if (deviceMatch) {
@@ -437,7 +444,6 @@ static std::vector<device> amendDeviceAndSubDevices(
437
444
}
438
445
} // /for
439
446
} // /for
440
-
441
447
return FinalResult;
442
448
}
443
449
@@ -500,17 +506,17 @@ platform_impl::get_devices(info::device_type DeviceType) const {
500
506
// The first step is to filter out devices that are not compatible with
501
507
// SYCL_DEVICE_FILTER or ONEAPI_DEVICE_SELECTOR. This is also the mechanism by
502
508
// which top level device ids are assigned.
503
- int PlatformDeviceIndex ;
509
+ std::vector< int > PlatformDeviceIndices ;
504
510
if (OdsTargetList) {
505
511
if (FilterList) {
506
512
throw sycl::exception (sycl::make_error_code (errc::invalid),
507
513
" ONEAPI_DEVICE_SELECTOR cannot be used in "
508
514
" conjunction with SYCL_DEVICE_FILTER" );
509
515
}
510
- PlatformDeviceIndex = filterDeviceFilter<ods_target_list, ods_target>(
516
+ PlatformDeviceIndices = filterDeviceFilter<ods_target_list, ods_target>(
511
517
PiDevices, MPlatform, OdsTargetList);
512
518
} else if (FilterList) {
513
- PlatformDeviceIndex = filterDeviceFilter<device_filter_list, device_filter>(
519
+ PlatformDeviceIndices = filterDeviceFilter<device_filter_list, device_filter>(
514
520
PiDevices, MPlatform, FilterList);
515
521
}
516
522
@@ -533,7 +539,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
533
539
// Otherwise, our last step is to revisit the devices, possibly replacing
534
540
// them with subdevices (which have been ignored until now)
535
541
return amendDeviceAndSubDevices (Backend, Res, OdsTargetList,
536
- PlatformDeviceIndex , PlatformImpl);
542
+ PlatformDeviceIndices , PlatformImpl);
537
543
}
538
544
539
545
bool platform_impl::has_extension (const std::string &ExtensionName) const {
0 commit comments