Skip to content

Commit 0790bf8

Browse files
committed
Fix urProgramCompileExp, urProgramBuildExp, and urProgramLinkExp definition to match spec
Signed-off-by: Spruit, Neil R <neil.r.spruit@intel.com>
1 parent 4b5e559 commit 0790bf8

File tree

4 files changed

+65
-108
lines changed

4 files changed

+65
-108
lines changed

source/adapters/cuda/program.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,16 +226,14 @@ urProgramCompile(ur_context_handle_t hContext, ur_program_handle_t hProgram,
226226
return UR_RESULT_SUCCESS;
227227
}
228228

229-
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_context_handle_t,
230-
ur_program_handle_t,
229+
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t,
231230
uint32_t,
232231
ur_device_handle_t *,
233232
const char *) {
234233
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
235234
}
236235

237-
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_context_handle_t,
238-
ur_program_handle_t,
236+
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
239237
uint32_t,
240238
ur_device_handle_t *,
241239
const char *) {

source/adapters/hip/program.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,14 @@ urProgramCompile(ur_context_handle_t hContext, ur_program_handle_t hProgram,
245245
return urProgramBuild(hContext, hProgram, pOptions);
246246
}
247247

248-
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_context_handle_t,
249-
ur_program_handle_t,
248+
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t,
250249
uint32_t,
251250
ur_device_handle_t *,
252251
const char *) {
253252
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
254253
}
255254

256-
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_context_handle_t,
257-
ur_program_handle_t,
255+
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
258256
uint32_t,
259257
ur_device_handle_t *,
260258
const char *) {

source/adapters/level_zero/program.cpp

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(
113113
const char *Options ///< [in][optional] pointer to build options
114114
///< null-terminated string.
115115
) {
116-
return urProgramBuildExp(Context, Program, 1, Context->Devices.data(),
117-
Options);
116+
return urProgramBuildExp(Program, 1, Context->Devices.data(), Options);
118117
}
119118

120119
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
121-
ur_context_handle_t Context, ///< [in] handle of the context instance.
122-
ur_program_handle_t Program, ///< [in] Handle of the program to build.
123-
uint32_t numDevices, ur_device_handle_t *phDevices,
124-
const char *Options ///< [in][optional] pointer to build options
125-
///< null-terminated string.
120+
ur_program_handle_t hProgram, ///< [in] Handle of the program to build.
121+
uint32_t numDevices, ///< [in] number of devices
122+
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to
123+
///< array of device handles
124+
const char *pOptions ///< [in][optional] pointer to build options
125+
///< null-terminated string.
126126
) {
127127
// TODO
128128
// Check if device belongs to associated context.
@@ -131,43 +131,42 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
131131
// UR_RESULT_ERROR_INVALID_VALUE);
132132

133133
// We should have either IL or native device code.
134-
UR_ASSERT(Program->Code, UR_RESULT_ERROR_INVALID_PROGRAM);
134+
UR_ASSERT(hProgram->Code, UR_RESULT_ERROR_INVALID_PROGRAM);
135135

136136
// It is legal to build a program created from either IL or from native
137137
// device code.
138-
if (Program->State != ur_program_handle_t_::IL &&
139-
Program->State != ur_program_handle_t_::Native) {
138+
if (hProgram->State != ur_program_handle_t_::IL &&
139+
hProgram->State != ur_program_handle_t_::Native) {
140140
return UR_RESULT_ERROR_INVALID_OPERATION;
141141
}
142142

143-
std::scoped_lock<ur_shared_mutex> Guard(Program->Mutex);
143+
std::scoped_lock<ur_shared_mutex> Guard(hProgram->Mutex);
144144

145145
// Ask Level Zero to build and load the native code onto the device.
146146
ZeStruct<ze_module_desc_t> ZeModuleDesc;
147-
ur_program_handle_t_::SpecConstantShim Shim(Program);
148-
ZeModuleDesc.format = (Program->State == ur_program_handle_t_::IL)
147+
ur_program_handle_t_::SpecConstantShim Shim(hProgram);
148+
ZeModuleDesc.format = (hProgram->State == ur_program_handle_t_::IL)
149149
? ZE_MODULE_FORMAT_IL_SPIRV
150150
: ZE_MODULE_FORMAT_NATIVE;
151-
ZeModuleDesc.inputSize = Program->CodeLength;
152-
ZeModuleDesc.pInputModule = Program->Code.get();
153-
ZeModuleDesc.pBuildFlags = Options;
151+
ZeModuleDesc.inputSize = hProgram->CodeLength;
152+
ZeModuleDesc.pInputModule = hProgram->Code.get();
153+
ZeModuleDesc.pBuildFlags = pOptions;
154154
ZeModuleDesc.pConstants = Shim.ze();
155155

156156
ze_device_handle_t ZeDevice = phDevices[0]->ZeDevice;
157-
ze_context_handle_t ZeContext = Program->Context->ZeContext;
158-
std::ignore = Context;
157+
ze_context_handle_t ZeContext = hProgram->Context->ZeContext;
159158
std::ignore = numDevices;
160159
ze_module_handle_t ZeModule = nullptr;
161160

162161
ur_result_t Result = UR_RESULT_SUCCESS;
163-
Program->State = ur_program_handle_t_::Exe;
162+
hProgram->State = ur_program_handle_t_::Exe;
164163
ze_result_t ZeResult =
165164
ZE_CALL_NOCHECK(zeModuleCreate, (ZeContext, ZeDevice, &ZeModuleDesc,
166-
&ZeModule, &Program->ZeBuildLog));
165+
&ZeModule, &hProgram->ZeBuildLog));
167166
if (ZeResult != ZE_RESULT_SUCCESS) {
168167
// We adjust ur_program below to avoid attempting to release zeModule when
169168
// RT calls urProgramRelease().
170-
Program->State = ur_program_handle_t_::Invalid;
169+
hProgram->State = ur_program_handle_t_::Invalid;
171170
Result = ze2urResult(ZeResult);
172171
if (ZeModule) {
173172
ZE_CALL_NOCHECK(zeModuleDestroy, (ZeModule));
@@ -179,9 +178,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
179178
// call to zeModuleDynamicLink. However, modules created with
180179
// urProgramBuild are supposed to be fully linked and ready to use.
181180
// Therefore, do an extra check now for unresolved symbols.
182-
ZeResult = checkUnresolvedSymbols(ZeModule, &Program->ZeBuildLog);
181+
ZeResult = checkUnresolvedSymbols(ZeModule, &hProgram->ZeBuildLog);
183182
if (ZeResult != ZE_RESULT_SUCCESS) {
184-
Program->State = ur_program_handle_t_::Invalid;
183+
hProgram->State = ur_program_handle_t_::Invalid;
185184
Result = (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE)
186185
? UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE
187186
: ze2urResult(ZeResult);
@@ -193,22 +192,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
193192
}
194193

195194
// We no longer need the IL / native code.
196-
Program->Code.reset();
197-
Program->ZeModule = ZeModule;
195+
hProgram->Code.reset();
196+
hProgram->ZeModule = ZeModule;
198197
return Result;
199198
}
200199

201200
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(
202-
ur_context_handle_t Context, ///< [in] handle of the context instance.
203201
ur_program_handle_t
204-
Program, ///< [in][out] handle of the program to compile.
205-
uint32_t numDevices, ur_device_handle_t *phDevices,
206-
const char *Options ///< [in][optional] pointer to build options
207-
///< null-terminated string.
202+
hProgram, ///< [in][out] handle of the program to compile.
203+
uint32_t numDevices, ///< [in] number of devices
204+
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to
205+
///< array of device handles
206+
const char *pOptions ///< [in][optional] pointer to build options
207+
///< null-terminated string.
208208
) {
209209
std::ignore = numDevices;
210210
std::ignore = phDevices;
211-
return urProgramCompile(Context, Program, Options);
211+
return urProgramCompile(hProgram->Context, hProgram, pOptions);
212212
}
213213

214214
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompile(
@@ -251,38 +251,37 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLink(
251251
ur_program_handle_t
252252
*Program ///< [out] pointer to handle of program object created.
253253
) {
254-
return urProgramLinkExp(Context, Count, Programs, 1, Context->Devices.data(),
254+
return urProgramLinkExp(Context, Count, Context->Devices.data(), 1, Programs,
255255
Options, Program);
256256
}
257257

258258
UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
259-
ur_context_handle_t Context, ///< [in] handle of the context instance.
259+
ur_context_handle_t hContext, ///< [in] handle of the context instance.
260260
uint32_t numDevices, ///< [in] number of devices
261261
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to
262262
///< array of device handles
263-
uint32_t Count, ///< [in] number of program handles in `phPrograms`.
264-
const ur_program_handle_t *Programs, ///< [in][range(0, count)] pointer to
265-
///< array of program handles.
266-
const char *Options, ///< [in][optional] pointer to linker options
267-
///< null-terminated string.
263+
uint32_t count, ///< [in] number of program handles in `phPrograms`.
264+
const ur_program_handle_t *phPrograms, ///< [in][range(0, count)] pointer to
265+
///< array of program handles.
266+
const char *pOptions, ///< [in][optional] pointer to linker options
267+
///< null-terminated string.
268268
ur_program_handle_t
269-
*Program ///< [out] pointer to handle of program object created.
269+
*phProgram ///< [out] pointer to handle of program object created.
270270
) {
271271
std::ignore = numDevices;
272-
273-
UR_ASSERT(Context->isValidDevice(phDevices[0]),
272+
UR_ASSERT(hContext->isValidDevice(phDevices[0]),
274273
UR_RESULT_ERROR_INVALID_DEVICE);
275274

276275
// We do not support any link flags at this time because the Level Zero API
277276
// does not have any way to pass flags that are specific to linking.
278-
if (Options && *Options != '\0') {
277+
if (pOptions && *pOptions != '\0') {
279278
std::string ErrorMessage(
280279
"Level Zero does not support kernel link flags: \"");
281-
ErrorMessage.append(Options);
280+
ErrorMessage.append(pOptions);
282281
ErrorMessage.push_back('\"');
283282
ur_program_handle_t_ *UrProgram = new ur_program_handle_t_(
284-
ur_program_handle_t_::Invalid, Context, ErrorMessage);
285-
*Program = reinterpret_cast<ur_program_handle_t>(UrProgram);
283+
ur_program_handle_t_::Invalid, hContext, ErrorMessage);
284+
*phProgram = reinterpret_cast<ur_program_handle_t>(UrProgram);
286285
return UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
287286
}
288287

@@ -299,11 +298,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
299298
// potential if there was some other code that holds more than one of these
300299
// locks simultaneously with "exclusive" access. However, there is no such
301300
// code like that, so this is also not a danger.
302-
std::vector<std::shared_lock<ur_shared_mutex>> Guards(Count);
303-
for (uint32_t I = 0; I < Count; I++) {
304-
std::shared_lock<ur_shared_mutex> Guard(Programs[I]->Mutex);
301+
std::vector<std::shared_lock<ur_shared_mutex>> Guards(count);
302+
for (uint32_t I = 0; I < count; I++) {
303+
std::shared_lock<ur_shared_mutex> Guard(phPrograms[I]->Mutex);
305304
Guards[I].swap(Guard);
306-
if (Programs[I]->State != ur_program_handle_t_::Object) {
305+
if (phPrograms[I]->State != ur_program_handle_t_::Object) {
307306
return UR_RESULT_ERROR_INVALID_OPERATION;
308307
}
309308
}
@@ -316,23 +315,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
316315
// Construct a ze_module_program_exp_desc_t which contains information about
317316
// all of the modules that will be linked together.
318317
ZeStruct<ze_module_program_exp_desc_t> ZeExtModuleDesc;
319-
std::vector<size_t> CodeSizes(Count);
320-
std::vector<const uint8_t *> CodeBufs(Count);
321-
std::vector<const char *> BuildFlagPtrs(Count);
322-
std::vector<const ze_module_constants_t *> SpecConstPtrs(Count);
318+
std::vector<size_t> CodeSizes(count);
319+
std::vector<const uint8_t *> CodeBufs(count);
320+
std::vector<const char *> BuildFlagPtrs(count);
321+
std::vector<const ze_module_constants_t *> SpecConstPtrs(count);
323322
std::vector<ur_program_handle_t_::SpecConstantShim> SpecConstShims;
324-
SpecConstShims.reserve(Count);
323+
SpecConstShims.reserve(count);
325324

326-
for (uint32_t I = 0; I < Count; I++) {
327-
ur_program_handle_t Program = Programs[I];
325+
for (uint32_t I = 0; I < count; I++) {
326+
ur_program_handle_t Program = phPrograms[I];
328327
CodeSizes[I] = Program->CodeLength;
329328
CodeBufs[I] = Program->Code.get();
330329
BuildFlagPtrs[I] = Program->BuildFlags.c_str();
331330
SpecConstShims.emplace_back(Program);
332331
SpecConstPtrs[I] = SpecConstShims[I].ze();
333332
}
334333

335-
ZeExtModuleDesc.count = Count;
334+
ZeExtModuleDesc.count = count;
336335
ZeExtModuleDesc.inputSizes = CodeSizes.data();
337336
ZeExtModuleDesc.pInputModules = CodeBufs.data();
338337
ZeExtModuleDesc.pBuildFlags = BuildFlagPtrs.data();
@@ -366,8 +365,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
366365
//
367366
// TODO: Remove this workaround when the driver is fixed.
368367
if (!phDevices[0]->Platform->ZeDriverModuleProgramExtensionFound ||
369-
(Count == 1)) {
370-
if (Count == 1) {
368+
(count == 1)) {
369+
if (count == 1) {
371370
ZeModuleDesc.pNext = nullptr;
372371
ZeModuleDesc.inputSize = ZeExtModuleDesc.inputSizes[0];
373372
ZeModuleDesc.pInputModule = ZeExtModuleDesc.pInputModules[0];
@@ -382,7 +381,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
382381

383382
// Call the Level Zero API to compile, link, and create the module.
384383
ze_device_handle_t ZeDevice = phDevices[0]->ZeDevice;
385-
ze_context_handle_t ZeContext = Context->ZeContext;
384+
ze_context_handle_t ZeContext = hContext->ZeContext;
386385
ze_module_handle_t ZeModule = nullptr;
387386
ze_module_build_log_handle_t ZeBuildLog = nullptr;
388387
ze_result_t ZeResult =
@@ -420,8 +419,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
420419
? ur_program_handle_t_::Exe
421420
: ur_program_handle_t_::Invalid;
422421
ur_program_handle_t_ *UrProgram =
423-
new ur_program_handle_t_(State, Context, ZeModule, ZeBuildLog);
424-
*Program = reinterpret_cast<ur_program_handle_t>(UrProgram);
422+
new ur_program_handle_t_(State, hContext, ZeModule, ZeBuildLog);
423+
*phProgram = reinterpret_cast<ur_program_handle_t>(UrProgram);
425424
} catch (const std::bad_alloc &) {
426425
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
427426
} catch (...) {

source/ur/ur.hpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -295,41 +295,3 @@ class UrReturnHelper {
295295
void *param_value;
296296
size_t *param_value_size_ret;
297297
};
298-
299-
// Needed to have compatibility with piProgramBuild
300-
// when passing a specific list of devices
301-
// See: https://github.com/oneapi-src/unified-runtime/issues/912
302-
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
303-
ur_context_handle_t hContext, ///< [in] handle of the context instance.
304-
ur_program_handle_t hProgram, ///< [in] Handle of the program to build.
305-
uint32_t numDevices, ur_device_handle_t *phDevices,
306-
const char *pOptions ///< [in][optional] pointer to build options
307-
///< null-terminated string.
308-
);
309-
310-
// Needed to have compatibility with piProgramCompile
311-
// when passing a specific list of devices
312-
// See: https://github.com/oneapi-src/unified-runtime/issues/912
313-
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(
314-
ur_context_handle_t Context, ///< [in] handle of the context instance.
315-
ur_program_handle_t
316-
Program, ///< [in][out] handle of the program to compile.
317-
uint32_t numDevices, ur_device_handle_t *phDevices,
318-
const char *Options ///< [in][optional] pointer to build options
319-
///< null-terminated string.
320-
);
321-
322-
// Needed to have compatibility with piProgramLink
323-
// when passing a specific list of devices
324-
// See: https://github.com/oneapi-src/unified-runtime/issues/912
325-
UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
326-
ur_context_handle_t Context, ///< [in] handle of the context instance.
327-
uint32_t Count, ///< [in] number of program handles in `phPrograms`.
328-
const ur_program_handle_t *Programs, ///< [in][range(0, count)] pointer to
329-
///< array of program handles.
330-
uint32_t numDevices, ur_device_handle_t *phDevices,
331-
const char *Options, ///< [in][optional] pointer to linker options
332-
///< null-terminated string.
333-
ur_program_handle_t
334-
*Program ///< [out] pointer to handle of program object created.
335-
);

0 commit comments

Comments
 (0)