Skip to content

Commit c86b841

Browse files
committed
Moved conflicted changes to setKernelParams
1 parent c55dc2a commit c86b841

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,27 +284,28 @@ setKernelParams(const ur_context_handle_t Context,
284284
CudaImplicitOffset);
285285
}
286286

287-
if (Context->getDevice()->maxLocalMemSizeChosen()) {
287+
auto Device = Context->getDevice();
288+
if (LocalSize > static_cast<uint32_t>(Device->getMaxCapacityLocalMem())) {
289+
setErrorMessage("Too much local memory allocated for device",
290+
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
291+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
292+
}
293+
294+
if (Device->maxLocalMemSizeChosen()) {
288295
// Set up local memory requirements for kernel.
289-
auto Device = Context->getDevice();
290296
if (Device->getMaxChosenLocalMem() < 0) {
291297
bool EnvVarHasURPrefix =
292-
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE") != nullptr;
298+
(std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE") != nullptr);
293299
setErrorMessage(EnvVarHasURPrefix ? "Invalid value specified for "
294300
"UR_CUDA_MAX_LOCAL_MEM_SIZE"
295301
: "Invalid value specified for "
296302
"SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE",
297303
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
298304
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
299305
}
300-
if (LocalSize > static_cast<uint32_t>(Device->getMaxCapacityLocalMem())) {
301-
setErrorMessage("Too much local memory allocated for device",
302-
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
303-
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
304-
}
305306
if (LocalSize > static_cast<uint32_t>(Device->getMaxChosenLocalMem())) {
306307
bool EnvVarHasURPrefix =
307-
std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE") != nullptr;
308+
(std::getenv("UR_CUDA_MAX_LOCAL_MEM_SIZE") != nullptr);
308309
setErrorMessage(
309310
EnvVarHasURPrefix
310311
? "Local memory for kernel exceeds the amount requested using "
@@ -319,6 +320,10 @@ setKernelParams(const ur_context_handle_t Context,
319320
UR_CHECK_ERROR(cuFuncSetAttribute(
320321
CuFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
321322
Device->getMaxChosenLocalMem()));
323+
324+
} else {
325+
UR_CHECK_ERROR(cuFuncSetAttribute(
326+
CuFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, LocalSize));
322327
}
323328

324329
} catch (ur_result_t Err) {

0 commit comments

Comments
 (0)