@@ -42,6 +42,23 @@ std::ostream &operator<<(std::ostream &out,
42
42
return out;
43
43
}
44
44
45
+ std::ostream &operator <<(std::ostream &out, const ur_device_handle_t &device) {
46
+ size_t size;
47
+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, 0 , nullptr , &size);
48
+ std::vector<char > name (size);
49
+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, size, name.data (), nullptr );
50
+ out << name.data ();
51
+ return out;
52
+ }
53
+
54
+ std::ostream &operator <<(std::ostream &out,
55
+ const std::vector<ur_device_handle_t > &devices) {
56
+ for (auto device : devices) {
57
+ out << " \n * \" " << device << " \" " ;
58
+ }
59
+ return out;
60
+ }
61
+
45
62
uur::PlatformEnvironment::PlatformEnvironment (int argc, char **argv)
46
63
: platform_options{parsePlatformOptions (argc, argv)} {
47
64
instance = this ;
@@ -101,14 +118,16 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
101
118
}
102
119
103
120
if (platform_options.platform_name .empty ()) {
104
- if (platforms.size () == 1 ) {
121
+
122
+ if (platforms.size () == 1 || platform_options.platforms_count == 1 ) {
105
123
platform = platforms[0 ];
106
124
} else {
107
125
std::stringstream ss_error;
108
126
ss_error << " Select a single platform from below using the "
109
127
" --platform=NAME "
110
128
" command-line option:"
111
- << platforms;
129
+ << platforms << std::endl
130
+ << " or set --platforms_count=1." ;
112
131
error = ss_error.str ();
113
132
return ;
114
133
}
@@ -137,7 +156,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
137
156
<< " \" not found. Select a single platform from below "
138
157
" using the "
139
158
" --platform=NAME command-line options:"
140
- << platforms;
159
+ << platforms << std::endl
160
+ << " or set --platforms_count=1." ;
141
161
error = ss_error.str ();
142
162
return ;
143
163
}
@@ -178,6 +198,10 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
178
198
arg, " --platform=" , sizeof (" --platform=" ) - 1 ) == 0 ) {
179
199
options.platform_name =
180
200
std::string (&arg[std::strlen (" --platform=" )]);
201
+ } else if (std::strncmp (arg, " --platforms_count=" ,
202
+ sizeof (" --platforms_count=" ) - 1 ) == 0 ) {
203
+ options.platforms_count = std::strtoul (
204
+ &arg[std::strlen (" --platforms_count=" )], nullptr , 10 );
181
205
}
182
206
}
183
207
@@ -193,10 +217,31 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
193
217
return options;
194
218
}
195
219
220
+ DevicesEnvironment::DeviceOptions
221
+ DevicesEnvironment::parseDeviceOptions (int argc, char **argv) {
222
+ DeviceOptions options;
223
+ for (int argi = 1 ; argi < argc; ++argi) {
224
+ const char *arg = argv[argi];
225
+ if (!(std::strcmp (arg, " -h" ) && std::strcmp (arg, " --help" ))) {
226
+ // TODO - print help
227
+ break ;
228
+ } else if (std::strncmp (arg, " --device=" , sizeof (" --device=" ) - 1 ) ==
229
+ 0 ) {
230
+ options.device_name = std::string (&arg[std::strlen (" --device=" )]);
231
+ } else if (std::strncmp (arg, " --devices_count=" ,
232
+ sizeof (" --devices_count=" ) - 1 ) == 0 ) {
233
+ options.devices_count = std::strtoul (
234
+ &arg[std::strlen (" --devices_count=" )], nullptr , 10 );
235
+ }
236
+ }
237
+ return options;
238
+ }
239
+
196
240
DevicesEnvironment *DevicesEnvironment::instance = nullptr ;
197
241
198
242
DevicesEnvironment::DevicesEnvironment (int argc, char **argv)
199
- : PlatformEnvironment(argc, argv) {
243
+ : PlatformEnvironment(argc, argv),
244
+ device_options (parseDeviceOptions(argc, argv)) {
200
245
instance = this ;
201
246
if (!error.empty ()) {
202
247
return ;
@@ -210,27 +255,64 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
210
255
error = " Could not find any devices associated with the platform" ;
211
256
return ;
212
257
}
213
- // Get the argument (test_devices_count) to limit test devices count.
214
- u_long count_set = 0 ;
215
- for (int i = 1 ; i < argc; ++i) {
216
- if (std::strcmp (argv[i], " --test_devices_count" ) == 0 && i + 1 < argc) {
217
- count_set = std::strtoul (argv[i + 1 ], nullptr , 10 );
218
- break ;
219
- }
220
- }
221
- // In case, the count_set is "0", the variable count will not be changed.
258
+
259
+ // Get the argument (devices_count) to limit test devices count.
260
+ // In case, the devices_count is "0", the variable count will not be changed.
222
261
// The CTS will run on all devices.
223
- if (count_set > (std::numeric_limits<uint32_t >::max)()) {
224
- error = " Invalid test_devices_count argument" ;
225
- return ;
226
- } else if (count_set > 0 ) {
227
- count = (std::min)(count, static_cast <uint32_t >(count_set));
228
- }
229
- devices.resize (count);
230
- if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
231
- nullptr )) {
232
- error = " urDeviceGet() failed to get devices." ;
233
- return ;
262
+ if (device_options.device_name .empty ()) {
263
+ if (device_options.devices_count >
264
+ (std::numeric_limits<uint32_t >::max)()) {
265
+ error = " Invalid devices_count argument" ;
266
+ return ;
267
+ } else if (device_options.devices_count > 0 ) {
268
+ count = (std::min)(
269
+ count, static_cast <uint32_t >(device_options.devices_count ));
270
+ }
271
+ devices.resize (count);
272
+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
273
+ nullptr )) {
274
+ error = " urDeviceGet() failed to get devices." ;
275
+ return ;
276
+ }
277
+ } else {
278
+ devices.resize (count);
279
+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
280
+ nullptr )) {
281
+ error = " urDeviceGet() failed to get devices." ;
282
+ return ;
283
+ }
284
+ for (u_long i = 0 ; i < count; i++) {
285
+ size_t size;
286
+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, 0 , nullptr ,
287
+ &size)) {
288
+ error = " urDeviceGetInfo() failed" ;
289
+ return ;
290
+ }
291
+ std::vector<char > device_name (size);
292
+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, size,
293
+ device_name.data (), nullptr )) {
294
+ error = " urDeviceGetInfo() failed" ;
295
+ return ;
296
+ }
297
+ if (device_options.device_name == device_name.data ()) {
298
+ device = devices[i];
299
+ devices.clear ();
300
+ devices.resize (1 );
301
+ devices[0 ] = device;
302
+ break ;
303
+ }
304
+ }
305
+ if (!device) {
306
+ std::stringstream ss_error;
307
+ ss_error << " Device \" " << device_options.device_name
308
+ << " \" not found. Select a single device from below "
309
+ " using the "
310
+ " --device=NAME command-line options:"
311
+ << devices << std::endl
312
+ << " or set --devices_count=COUNT." ;
313
+ error = ss_error.str ();
314
+ return ;
315
+ }
234
316
}
235
317
}
236
318
0 commit comments