Skip to content

Commit 4cc1094

Browse files
authored
[SYCL] Fix ODS filtering bug when platforms have multiple devices (#7425)
There is a filtering bug when the ONEAPI_DEVICE_SELECTOR (ODS) environment variable is used in environments where a platform has more than one device. One manifestation of this bug is that valid devices are incorrectly excluded from the list of available devices. This PR attempts to fix the problem by having the filter return the original index of every device it keeps, instead of a single starting index for the whole platform.
1 parent 2d1319f commit 4cc1094

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

sycl/source/detail/platform_impl.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,10 @@ std::vector<platform> platform_impl::get_platforms() {
153153
// to distinguish the case where we are working with ONEAPI_DEVICE_SELECTOR
154154
// in the places where the functionality diverges between these two
155155
// environment variables.
156+
// The return value is a vector that represents the indices of the chosen
157+
// devices.
156158
template <typename ListT, typename FilterT>
157-
static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
159+
static std::vector<int> filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
158160
RT::PiPlatform Platform, ListT *FilterList) {
159161

160162
constexpr bool is_ods_target = std::is_same_v<FilterT, ods_target>;
@@ -184,22 +186,23 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
184186
// in the if statement above because it will then be out of scope in the rest
185187
// of the function
186188
std::map<RT::PiDevice *, bool> Blacklist;
189+
std::vector<int> original_indices;
187190

188191
std::vector<plugin> &Plugins = RT::initialize();
189192
auto It =
190193
std::find_if(Plugins.begin(), Plugins.end(), [Platform](plugin &Plugin) {
191194
return Plugin.containsPiPlatform(Platform);
192195
});
193-
if (It == Plugins.end())
194-
return -1;
196+
if (It == Plugins.end()) {
197+
return original_indices;
198+
}
195199
plugin &Plugin = *It;
196200
backend Backend = Plugin.getBackend();
197201
int InsertIDx = 0;
198202
// DeviceIds should be given consecutive numbers across platforms in the same
199203
// backend
200204
std::lock_guard<std::mutex> Guard(*Plugin.getPluginMutex());
201205
int DeviceNum = Plugin.getStartingDeviceId(Platform);
202-
int StartingNum = DeviceNum;
203206
for (RT::PiDevice Device : PiDevices) {
204207
RT::PiDeviceType PiDevType;
205208
Plugin.call<PiApiKind::piDeviceGetInfo>(Device, PI_DEVICE_INFO_TYPE,
@@ -223,6 +226,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
223226
if (!Blacklist[&Device]) { // ensure it is not blacklisted
224227
if (!Filter.IsNegativeTarget) { // is filter positive?
225228
PiDevices[InsertIDx++] = Device;
229+
original_indices.push_back(DeviceNum);
226230
} else {
227231
// Filter is negative and the device matches the filter so
228232
// blacklist the device.
@@ -231,6 +235,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
231235
}
232236
} else { // dealing with SYCL_DEVICE_FILTER
233237
PiDevices[InsertIDx++] = Device;
238+
original_indices.push_back(DeviceNum);
234239
}
235240
break;
236241
}
@@ -241,6 +246,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
241246
if (!Blacklist[&Device]) {
242247
if (!Filter.IsNegativeTarget) {
243248
PiDevices[InsertIDx++] = Device;
249+
original_indices.push_back(DeviceNum);
244250
} else {
245251
// Filter is negative and the device matches the filter so
246252
// blacklist the device.
@@ -249,6 +255,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
249255
}
250256
} else {
251257
PiDevices[InsertIDx++] = Device;
258+
original_indices.push_back(DeviceNum);
252259
}
253260
break;
254261
}
@@ -262,7 +269,7 @@ static int filterDeviceFilter(std::vector<RT::PiDevice> &PiDevices,
262269
// to assign a unique device id number across platforms that belong to
263270
// the same backend. For example, opencl:cpu:0, opencl:acc:1, opencl:gpu:2
264271
Plugin.setLastDeviceId(Platform, DeviceNum);
265-
return StartingNum;
272+
return original_indices;
266273
}
267274

268275
std::shared_ptr<device_impl>
@@ -307,7 +314,7 @@ static bool supportsPartitionProperty(const device &dev,
307314

308315
static std::vector<device> amendDeviceAndSubDevices(
309316
backend PlatformBackend, std::vector<device> &DeviceList,
310-
ods_target_list *OdsTargetList, int PlatformDeviceIndex,
317+
ods_target_list *OdsTargetList, const std::vector<int>& original_indices,
311318
PlatformImplPtr PlatformImpl) {
312319
constexpr info::partition_property partitionProperty =
313320
info::partition_property::partition_by_affinity_domain;
@@ -335,7 +342,7 @@ static std::vector<device> amendDeviceAndSubDevices(
335342

336343
} else if (target.DeviceNum) { // opencl:0
337344
deviceMatch =
338-
(target.DeviceNum.value() == PlatformDeviceIndex + (int)i);
345+
(target.DeviceNum.value() == original_indices[i]);
339346
}
340347

341348
if (deviceMatch) {
@@ -437,7 +444,6 @@ static std::vector<device> amendDeviceAndSubDevices(
437444
}
438445
} // /for
439446
} // /for
440-
441447
return FinalResult;
442448
}
443449

@@ -500,17 +506,17 @@ platform_impl::get_devices(info::device_type DeviceType) const {
500506
// The first step is to filter out devices that are not compatible with
501507
// SYCL_DEVICE_FILTER or ONEAPI_DEVICE_SELECTOR. This is also the mechanism by
502508
// which top level device ids are assigned.
503-
int PlatformDeviceIndex;
509+
std::vector<int> PlatformDeviceIndices;
504510
if (OdsTargetList) {
505511
if (FilterList) {
506512
throw sycl::exception(sycl::make_error_code(errc::invalid),
507513
"ONEAPI_DEVICE_SELECTOR cannot be used in "
508514
"conjunction with SYCL_DEVICE_FILTER");
509515
}
510-
PlatformDeviceIndex = filterDeviceFilter<ods_target_list, ods_target>(
516+
PlatformDeviceIndices = filterDeviceFilter<ods_target_list, ods_target>(
511517
PiDevices, MPlatform, OdsTargetList);
512518
} else if (FilterList) {
513-
PlatformDeviceIndex = filterDeviceFilter<device_filter_list, device_filter>(
519+
PlatformDeviceIndices = filterDeviceFilter<device_filter_list, device_filter>(
514520
PiDevices, MPlatform, FilterList);
515521
}
516522

@@ -533,7 +539,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
533539
// Otherwise, our last step is to revisit the devices, possibly replacing
534540
// them with subdevices (which have been ignored until now)
535541
return amendDeviceAndSubDevices(Backend, Res, OdsTargetList,
536-
PlatformDeviceIndex, PlatformImpl);
542+
PlatformDeviceIndices, PlatformImpl);
537543
}
538544

539545
bool platform_impl::has_extension(const std::string &ExtensionName) const {

0 commit comments

Comments
 (0)