Skip to content

Commit 0647802

Browse files
committed
Fix urUSMGetMemAllocInfo
Now all devs are in a single platform this needs updating. Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
1 parent 905804c commit 0647802

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

source/adapters/cuda/usm.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -258,16 +258,14 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
258258
CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
259259
(CUdeviceptr)pMem));
260260

261-
// currently each device is in its own platform, so find the platform at
262-
// the same index
263-
std::vector<ur_platform_handle_t> Platforms;
264-
Platforms.resize(DeviceIndex + 1);
261+
// cuda backend has only one platform containing all devices
262+
ur_platform_handle_t platform;
265263
ur_adapter_handle_t AdapterHandle = &adapter;
266-
Result = urPlatformGet(&AdapterHandle, 1, DeviceIndex + 1,
267-
Platforms.data(), nullptr);
264+
Result = urPlatformGet(&AdapterHandle, 1, 1,
265+
&platform, nullptr);
268266

269267
// get the device from the platform
270-
ur_device_handle_t Device = Platforms[DeviceIndex]->Devices[0].get();
268+
ur_device_handle_t Device = platform->Devices[DeviceIndex].get();
271269
return ReturnValue(Device);
272270
}
273271
case UR_USM_ALLOC_INFO_POOL: {

source/adapters/hip/usm.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,14 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
207207

208208
int DeviceIdx = hipPointerAttributeType.device;
209209

210-
// currently each device is in its own platform, so find the platform at
211-
// the same index
212-
std::vector<ur_platform_handle_t> Platforms;
213-
Platforms.resize(DeviceIdx + 1);
210+
// hip backend has only one platform containing all devices
211+
ur_platform_handle_t platform;
214212
ur_adapter_handle_t AdapterHandle = &adapter;
215-
Result = urPlatformGet(&AdapterHandle, 1, DeviceIdx + 1, Platforms.data(),
216-
nullptr);
213+
Result = urPlatformGet(&AdapterHandle, 1, 1,
214+
&platform, nullptr);
217215

218216
// get the device from the platform
219-
ur_device_handle_t Device = Platforms[DeviceIdx]->Devices[0].get();
217+
ur_device_handle_t Device = platform->Devices[DeviceIdx].get();
220218
return ReturnValue(Device);
221219
}
222220
case UR_USM_ALLOC_INFO_POOL: {

0 commit comments

Comments
 (0)