3
3
// See LICENSE.TXT
4
4
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5
5
6
+ #include < algorithm>
6
7
#include < cstring>
7
8
#include < fstream>
8
9
@@ -41,6 +42,23 @@ std::ostream &operator<<(std::ostream &out,
41
42
return out;
42
43
}
43
44
45
+ std::ostream &operator <<(std::ostream &out, const ur_device_handle_t &device) {
46
+ size_t size;
47
+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, 0 , nullptr , &size);
48
+ std::vector<char > name (size);
49
+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, size, name.data (), nullptr );
50
+ out << name.data ();
51
+ return out;
52
+ }
53
+
54
+ std::ostream &operator <<(std::ostream &out,
55
+ const std::vector<ur_device_handle_t > &devices) {
56
+ for (auto device : devices) {
57
+ out << " \n * \" " << device << " \" " ;
58
+ }
59
+ return out;
60
+ }
61
+
44
62
uur::PlatformEnvironment::PlatformEnvironment (int argc, char **argv)
45
63
: platform_options{parsePlatformOptions (argc, argv)} {
46
64
instance = this ;
@@ -100,14 +118,16 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
100
118
}
101
119
102
120
if (platform_options.platform_name .empty ()) {
103
- if (platforms.size () == 1 ) {
121
+
122
+ if (platforms.size () == 1 || platform_options.platforms_count == 1 ) {
104
123
platform = platforms[0 ];
105
124
} else {
106
125
std::stringstream ss_error;
107
126
ss_error << " Select a single platform from below using the "
108
127
" --platform=NAME "
109
128
" command-line option:"
110
- << platforms;
129
+ << platforms << std::endl
130
+ << " or set --platforms_count=1." ;
111
131
error = ss_error.str ();
112
132
return ;
113
133
}
@@ -136,7 +156,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
136
156
<< " \" not found. Select a single platform from below "
137
157
" using the "
138
158
" --platform=NAME command-line options:"
139
- << platforms;
159
+ << platforms << std::endl
160
+ << " or set --platforms_count=1." ;
140
161
error = ss_error.str ();
141
162
return ;
142
163
}
@@ -177,6 +198,10 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
177
198
arg, " --platform=" , sizeof (" --platform=" ) - 1 ) == 0 ) {
178
199
options.platform_name =
179
200
std::string (&arg[std::strlen (" --platform=" )]);
201
+ } else if (std::strncmp (arg, " --platforms_count=" ,
202
+ sizeof (" --platforms_count=" ) - 1 ) == 0 ) {
203
+ options.platforms_count = std::strtoul (
204
+ &arg[std::strlen (" --platforms_count=" )], nullptr , 10 );
180
205
}
181
206
}
182
207
@@ -192,10 +217,31 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
192
217
return options;
193
218
}
194
219
220
+ DevicesEnvironment::DeviceOptions
221
+ DevicesEnvironment::parseDeviceOptions (int argc, char **argv) {
222
+ DeviceOptions options;
223
+ for (int argi = 1 ; argi < argc; ++argi) {
224
+ const char *arg = argv[argi];
225
+ if (!(std::strcmp (arg, " -h" ) && std::strcmp (arg, " --help" ))) {
226
+ // TODO - print help
227
+ break ;
228
+ } else if (std::strncmp (arg, " --device=" , sizeof (" --device=" ) - 1 ) ==
229
+ 0 ) {
230
+ options.device_name = std::string (&arg[std::strlen (" --device=" )]);
231
+ } else if (std::strncmp (arg, " --devices_count=" ,
232
+ sizeof (" --devices_count=" ) - 1 ) == 0 ) {
233
+ options.devices_count = std::strtoul (
234
+ &arg[std::strlen (" --devices_count=" )], nullptr , 10 );
235
+ }
236
+ }
237
+ return options;
238
+ }
239
+
195
240
DevicesEnvironment *DevicesEnvironment::instance = nullptr ;
196
241
197
242
DevicesEnvironment::DevicesEnvironment (int argc, char **argv)
198
- : PlatformEnvironment(argc, argv) {
243
+ : PlatformEnvironment(argc, argv),
244
+ device_options (parseDeviceOptions(argc, argv)) {
199
245
instance = this ;
200
246
if (!error.empty ()) {
201
247
return ;
@@ -209,11 +255,64 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
209
255
error = " Could not find any devices associated with the platform" ;
210
256
return ;
211
257
}
212
- devices.resize (count);
213
- if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
214
- nullptr )) {
215
- error = " urDeviceGet() failed to get devices." ;
216
- return ;
258
+
259
+ // Get the argument (devices_count) to limit test devices count.
260
+ // In case, the devices_count is "0", the variable count will not be changed.
261
+ // The CTS will run on all devices.
262
+ if (device_options.device_name .empty ()) {
263
+ if (device_options.devices_count >
264
+ (std::numeric_limits<uint32_t >::max)()) {
265
+ error = " Invalid devices_count argument" ;
266
+ return ;
267
+ } else if (device_options.devices_count > 0 ) {
268
+ count = (std::min)(
269
+ count, static_cast <uint32_t >(device_options.devices_count ));
270
+ }
271
+ devices.resize (count);
272
+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
273
+ nullptr )) {
274
+ error = " urDeviceGet() failed to get devices." ;
275
+ return ;
276
+ }
277
+ } else {
278
+ devices.resize (count);
279
+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
280
+ nullptr )) {
281
+ error = " urDeviceGet() failed to get devices." ;
282
+ return ;
283
+ }
284
+ for (u_long i = 0 ; i < count; i++) {
285
+ size_t size;
286
+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, 0 , nullptr ,
287
+ &size)) {
288
+ error = " urDeviceGetInfo() failed" ;
289
+ return ;
290
+ }
291
+ std::vector<char > device_name (size);
292
+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, size,
293
+ device_name.data (), nullptr )) {
294
+ error = " urDeviceGetInfo() failed" ;
295
+ return ;
296
+ }
297
+ if (device_options.device_name == device_name.data ()) {
298
+ device = devices[i];
299
+ devices.clear ();
300
+ devices.resize (1 );
301
+ devices[0 ] = device;
302
+ break ;
303
+ }
304
+ }
305
+ if (!device) {
306
+ std::stringstream ss_error;
307
+ ss_error << " Device \" " << device_options.device_name
308
+ << " \" not found. Select a single device from below "
309
+ " using the "
310
+ " --device=NAME command-line options:"
311
+ << devices << std::endl
312
+ << " or set --devices_count=COUNT." ;
313
+ error = ss_error.str ();
314
+ return ;
315
+ }
217
316
}
218
317
}
219
318
0 commit comments