Skip to content

Commit ea13b2f

Browse files
[DeviceSanitizer] Handle the case of urMemGetNativeHandle getting a nullptr Device (#1969)
Co-authored-by: Yang Zhao <allanzyne@outlook.com>
1 parent d52dccb commit ea13b2f

File tree

4 files changed

+37
-9
lines changed

4 files changed

+37
-9
lines changed

source/loader/layers/sanitizer/asan_buffer.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
7575
return UR_RESULT_SUCCESS;
7676
}
7777

78+
// Device may be null, we follow the L0 adapter's practice to use the first
79+
// device
80+
if (!Device) {
81+
auto Devices = GetDevices(Context);
82+
assert(Devices.size() > 0 && "Devices should not be empty");
83+
Device = Devices[0];
84+
}
85+
assert((void *)Device != nullptr && "Device cannot be nullptr");
86+
7887
std::scoped_lock<ur_shared_mutex> Guard(Mutex);
7988
auto &Allocation = Allocations[Device];
8089
ur_result_t URes = UR_RESULT_SUCCESS;

source/loader/layers/sanitizer/asan_interceptor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ ur_result_t SanitizerInterceptor::updateShadowMemory(
571571
ur_result_t
572572
SanitizerInterceptor::registerDeviceGlobals(ur_context_handle_t Context,
573573
ur_program_handle_t Program) {
574-
std::vector<ur_device_handle_t> Devices = GetProgramDevices(Program);
574+
std::vector<ur_device_handle_t> Devices = GetDevices(Program);
575575

576576
auto ContextInfo = getContextInfo(Context);
577577

source/loader/layers/sanitizer/ur_sanitizer_utils.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ ur_device_handle_t GetDevice(ur_queue_handle_t Queue) {
7272
return Device;
7373
}
7474

75+
std::vector<ur_device_handle_t> GetDevices(ur_context_handle_t Context) {
76+
std::vector<ur_device_handle_t> Devices{};
77+
uint32_t DeviceNum = 0;
78+
[[maybe_unused]] ur_result_t Result;
79+
Result = getContext()->urDdiTable.Context.pfnGetInfo(
80+
Context, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(uint32_t), &DeviceNum,
81+
nullptr);
82+
assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed");
83+
Devices.resize(DeviceNum);
84+
Result = getContext()->urDdiTable.Context.pfnGetInfo(
85+
Context, UR_CONTEXT_INFO_DEVICES,
86+
sizeof(ur_device_handle_t) * DeviceNum, Devices.data(), nullptr);
87+
assert(Result == UR_RESULT_SUCCESS && "getDevices(Context) failed");
88+
return Devices;
89+
}
90+
7591
ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel) {
7692
ur_program_handle_t Program{};
7793
[[maybe_unused]] auto Result = getContext()->urDdiTable.Kernel.pfnGetInfo(
@@ -169,18 +185,20 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device,
169185
return (bool)Flag;
170186
}
171187

172-
std::vector<ur_device_handle_t> GetProgramDevices(ur_program_handle_t Program) {
173-
size_t PropSize;
188+
std::vector<ur_device_handle_t> GetDevices(ur_program_handle_t Program) {
189+
uint32_t DeviceNum = 0;
174190
[[maybe_unused]] ur_result_t Result =
175191
getContext()->urDdiTable.Program.pfnGetInfo(
176-
Program, UR_PROGRAM_INFO_DEVICES, 0, nullptr, &PropSize);
177-
assert(Result == UR_RESULT_SUCCESS);
192+
Program, UR_PROGRAM_INFO_NUM_DEVICES, sizeof(DeviceNum), &DeviceNum,
193+
nullptr);
194+
assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed");
178195

179196
std::vector<ur_device_handle_t> Devices;
180-
Devices.resize(PropSize / sizeof(ur_device_handle_t));
197+
Devices.resize(DeviceNum);
181198
Result = getContext()->urDdiTable.Program.pfnGetInfo(
182-
Program, UR_PROGRAM_INFO_DEVICES, PropSize, Devices.data(), nullptr);
183-
assert(Result == UR_RESULT_SUCCESS);
199+
Program, UR_PROGRAM_INFO_DEVICES,
200+
DeviceNum * sizeof(ur_device_handle_t), Devices.data(), nullptr);
201+
assert(Result == UR_RESULT_SUCCESS && "getDevices(Program) failed");
184202

185203
return Devices;
186204
}

source/loader/layers/sanitizer/ur_sanitizer_utils.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ ur_context_handle_t GetContext(ur_queue_handle_t Queue);
3434
ur_context_handle_t GetContext(ur_program_handle_t Program);
3535
ur_context_handle_t GetContext(ur_kernel_handle_t Kernel);
3636
ur_device_handle_t GetDevice(ur_queue_handle_t Queue);
37+
std::vector<ur_device_handle_t> GetDevices(ur_context_handle_t Context);
38+
std::vector<ur_device_handle_t> GetDevices(ur_program_handle_t Program);
3739
DeviceType GetDeviceType(ur_context_handle_t Context,
3840
ur_device_handle_t Device);
3941
ur_device_handle_t GetParentDevice(ur_device_handle_t Device);
@@ -42,7 +44,6 @@ bool GetDeviceUSMCapability(ur_device_handle_t Device,
4244
std::string GetKernelName(ur_kernel_handle_t Kernel);
4345
size_t GetDeviceLocalMemorySize(ur_device_handle_t Device);
4446
ur_program_handle_t GetProgram(ur_kernel_handle_t Kernel);
45-
std::vector<ur_device_handle_t> GetProgramDevices(ur_program_handle_t Program);
4647
ur_device_handle_t GetUSMAllocDevice(ur_context_handle_t Context,
4748
const void *MemPtr);
4849
uint32_t GetKernelNumArgs(ur_kernel_handle_t Kernel);

0 commit comments

Comments
 (0)