@@ -152,41 +152,21 @@ UR_APIEXPORT ur_result_t UR_APICALL
152
152
urUSMGetMemAllocInfo (ur_context_handle_t hContext, const void *pMem,
153
153
ur_usm_alloc_info_t propName, size_t propValueSize,
154
154
void *pPropValue, size_t *pPropValueSizeRet) {
155
- ur_result_t Result = UR_RESULT_SUCCESS;
156
- hipPointerAttribute_t hipPointerAttributeType;
157
-
158
155
UrReturnHelper ReturnValue (propValueSize, pPropValue, pPropValueSizeRet);
159
156
160
157
try {
161
158
switch (propName) {
162
159
case UR_USM_ALLOC_INFO_TYPE: {
163
- // do not throw if hipPointerGetAttribute returns hipErrorInvalidValue
164
- hipError_t Ret = hipPointerGetAttributes (&hipPointerAttributeType, pMem);
165
- if (Ret == hipErrorInvalidValue) {
166
- // pointer not known to the HIP subsystem
167
- return ReturnValue (UR_USM_TYPE_UNKNOWN);
168
- }
169
- // Direct usage of the function, instead of UR_CHECK_ERROR, so we can get
170
- // the line offset.
171
- checkErrorUR (Ret, __func__, __LINE__ - 5 , __FILE__);
172
- // ROCm 6.0.0 introduces hipMemoryTypeUnregistered in the hipMemoryType
173
- // enum to mark unregistered allocations (i.e., via system allocators).
174
- #if HIP_VERSION_MAJOR >= 6
175
- if (hipPointerAttributeType.type == hipMemoryTypeUnregistered) {
160
+ auto MaybePointerAttrs = getPointerAttributes (pMem);
161
+ if (!MaybePointerAttrs.has_value ()) {
176
162
// pointer not known to the HIP subsystem
177
163
return ReturnValue (UR_USM_TYPE_UNKNOWN);
178
164
}
179
- #endif
180
- unsigned int Value;
181
- #if HIP_VERSION >= 50600000
182
- Value = hipPointerAttributeType.type ;
183
- #else
184
- Value = hipPointerAttributeType.memoryType ;
185
- #endif
165
+ auto Value = getMemoryType (*MaybePointerAttrs);
186
166
UR_ASSERT (Value == hipMemoryTypeDevice || Value == hipMemoryTypeHost ||
187
167
Value == hipMemoryTypeManaged,
188
168
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
189
- if (hipPointerAttributeType. isManaged || Value == hipMemoryTypeManaged) {
169
+ if (MaybePointerAttrs-> isManaged || Value == hipMemoryTypeManaged) {
190
170
// pointer to managed memory
191
171
return ReturnValue (UR_USM_TYPE_SHARED);
192
172
}
@@ -202,15 +182,18 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
202
182
ur::unreachable ();
203
183
}
204
184
case UR_USM_ALLOC_INFO_DEVICE: {
205
- // get device index associated with this pointer
206
- UR_CHECK_ERROR (hipPointerGetAttributes (&hipPointerAttributeType, pMem));
185
+ auto MaybePointerAttrs = getPointerAttributes (pMem);
186
+ if (!MaybePointerAttrs.has_value ()) {
187
+ // pointer not known to the HIP subsystem
188
+ return ReturnValue (UR_USM_TYPE_UNKNOWN);
189
+ }
207
190
208
- int DeviceIdx = hipPointerAttributeType. device ;
191
+ int DeviceIdx = MaybePointerAttrs-> device ;
209
192
210
193
// hip backend has only one platform containing all devices
211
194
ur_platform_handle_t platform;
212
195
ur_adapter_handle_t AdapterHandle = &adapter;
213
- Result = urPlatformGet (&AdapterHandle, 1 , 1 , &platform, nullptr );
196
+ UR_CHECK_ERROR ( urPlatformGet (&AdapterHandle, 1 , 1 , &platform, nullptr ) );
214
197
215
198
// get the device from the platform
216
199
ur_device_handle_t Device = platform->Devices [DeviceIdx].get ();
@@ -227,20 +210,32 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
227
210
}
228
211
return ReturnValue (Pool);
229
212
}
213
+ case UR_USM_ALLOC_INFO_BASE_PTR:
214
+ // HIP gives us the ability to query the base pointer for a device
215
+ // pointer, so check whether we've got one of those.
216
+ if (auto MaybePointerAttrs = getPointerAttributes (pMem)) {
217
+ if (getMemoryType (*MaybePointerAttrs) == hipMemoryTypeDevice) {
218
+ void *Base = nullptr ;
219
+ UR_CHECK_ERROR (hipPointerGetAttribute (
220
+ &Base, HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
221
+ (hipDeviceptr_t)pMem));
222
+ return ReturnValue (Base);
223
+ }
224
+ }
225
+ // If not, we can't be sure.
226
+ return UR_RESULT_ERROR_INVALID_VALUE;
230
227
case UR_USM_ALLOC_INFO_SIZE: {
231
228
size_t RangeSize = 0 ;
232
229
UR_CHECK_ERROR (hipMemPtrGetInfo (const_cast <void *>(pMem), &RangeSize));
233
230
return ReturnValue (RangeSize);
234
231
}
235
- case UR_USM_ALLOC_INFO_BASE_PTR:
236
- return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
237
232
default :
238
233
return UR_RESULT_ERROR_INVALID_ENUMERATION;
239
234
}
240
235
} catch (ur_result_t Error) {
241
- Result = Error;
236
+ return Error;
242
237
}
243
- return Result ;
238
+ return UR_RESULT_SUCCESS ;
244
239
}
245
240
246
241
UR_APIEXPORT ur_result_t UR_APICALL urUSMImportExp (ur_context_handle_t Context,
0 commit comments