@@ -284,27 +284,28 @@ setKernelParams(const ur_context_handle_t Context,
284
284
CudaImplicitOffset);
285
285
}
286
286
287
- if (Context->getDevice ()->maxLocalMemSizeChosen ()) {
287
+ auto Device = Context->getDevice ();
288
+ if (LocalSize > static_cast <uint32_t >(Device->getMaxCapacityLocalMem ())) {
289
+ setErrorMessage (" Too much local memory allocated for device" ,
290
+ UR_RESULT_ERROR_ADAPTER_SPECIFIC);
291
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
292
+ }
293
+
294
+ if (Device->maxLocalMemSizeChosen ()) {
288
295
// Set up local memory requirements for kernel.
289
- auto Device = Context->getDevice ();
290
296
if (Device->getMaxChosenLocalMem () < 0 ) {
291
297
bool EnvVarHasURPrefix =
292
- std::getenv (" UR_CUDA_MAX_LOCAL_MEM_SIZE" ) != nullptr ;
298
+ ( std::getenv (" UR_CUDA_MAX_LOCAL_MEM_SIZE" ) != nullptr ) ;
293
299
setErrorMessage (EnvVarHasURPrefix ? " Invalid value specified for "
294
300
" UR_CUDA_MAX_LOCAL_MEM_SIZE"
295
301
: " Invalid value specified for "
296
302
" SYCL_PI_CUDA_MAX_LOCAL_MEM_SIZE" ,
297
303
UR_RESULT_ERROR_ADAPTER_SPECIFIC);
298
304
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
299
305
}
300
- if (LocalSize > static_cast <uint32_t >(Device->getMaxCapacityLocalMem ())) {
301
- setErrorMessage (" Too much local memory allocated for device" ,
302
- UR_RESULT_ERROR_ADAPTER_SPECIFIC);
303
- return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
304
- }
305
306
if (LocalSize > static_cast <uint32_t >(Device->getMaxChosenLocalMem ())) {
306
307
bool EnvVarHasURPrefix =
307
- std::getenv (" UR_CUDA_MAX_LOCAL_MEM_SIZE" ) != nullptr ;
308
+ ( std::getenv (" UR_CUDA_MAX_LOCAL_MEM_SIZE" ) != nullptr ) ;
308
309
setErrorMessage (
309
310
EnvVarHasURPrefix
310
311
? " Local memory for kernel exceeds the amount requested using "
@@ -319,6 +320,10 @@ setKernelParams(const ur_context_handle_t Context,
319
320
UR_CHECK_ERROR (cuFuncSetAttribute (
320
321
CuFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
321
322
Device->getMaxChosenLocalMem ()));
323
+
324
+ } else {
325
+ UR_CHECK_ERROR (cuFuncSetAttribute (
326
+ CuFunc, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, LocalSize));
322
327
}
323
328
324
329
} catch (ur_result_t Err) {
0 commit comments