Skip to content

Commit f06bc02

Browse files
authored
Merge pull request #1669 from JackAKirk/fix-usmptr-get-dev
[CUDA][HIP] Fix urUSMGetMemAllocInfo impl to use single platform.
2 parents 42c0b02 + 9868e3b commit f06bc02

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

source/adapters/cuda/usm.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,16 +258,13 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
258258
CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
259259
(CUdeviceptr)pMem));
260260

261-
// currently each device is in its own platform, so find the platform at
262-
// the same index
263-
std::vector<ur_platform_handle_t> Platforms;
264-
Platforms.resize(DeviceIndex + 1);
261+
// cuda backend has only one platform containing all devices
262+
ur_platform_handle_t platform;
265263
ur_adapter_handle_t AdapterHandle = &adapter;
266-
Result = urPlatformGet(&AdapterHandle, 1, DeviceIndex + 1,
267-
Platforms.data(), nullptr);
264+
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
268265

269266
// get the device from the platform
270-
ur_device_handle_t Device = Platforms[DeviceIndex]->Devices[0].get();
267+
ur_device_handle_t Device = platform->Devices[DeviceIndex].get();
271268
return ReturnValue(Device);
272269
}
273270
case UR_USM_ALLOC_INFO_POOL: {

source/adapters/hip/usm.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,13 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
207207

208208
int DeviceIdx = hipPointerAttributeType.device;
209209

210-
// currently each device is in its own platform, so find the platform at
211-
// the same index
212-
std::vector<ur_platform_handle_t> Platforms;
213-
Platforms.resize(DeviceIdx + 1);
210+
// hip backend has only one platform containing all devices
211+
ur_platform_handle_t platform;
214212
ur_adapter_handle_t AdapterHandle = &adapter;
215-
Result = urPlatformGet(&AdapterHandle, 1, DeviceIdx + 1, Platforms.data(),
216-
nullptr);
213+
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
217214

218215
// get the device from the platform
219-
ur_device_handle_t Device = Platforms[DeviceIdx]->Devices[0].get();
216+
ur_device_handle_t Device = platform->Devices[DeviceIdx].get();
220217
return ReturnValue(Device);
221218
}
222219
case UR_USM_ALLOC_INFO_POOL: {

0 commit comments

Comments
 (0)