#pragma once

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <new>
#include <set>

#include <ur_api.h>

#include "common.hpp"
|
19 | 20 |
|
20 | 21 | namespace native_cpu {
|
// Header describing a single USM allocation. An instance of this struct is
// stored immediately in front of the pointer handed back to the user so that
// urUSMGetMemAllocInfo queries can be answered without a side lookup table.
struct usm_alloc_info {
  // Kind of USM allocation (host/device/shared), UR_USM_TYPE_UNKNOWN for the
  // null entry.
  ur_usm_type_t type;
  // Pointer returned to the user (i.e. the address just past this header and
  // any padding).
  const void *base_ptr;
  // Size in bytes requested by the user (excludes header and padding).
  size_t size;
  // Device the allocation is associated with.
  ur_device_handle_t device;
  // Pool the allocation was made from, if any.
  ur_usm_pool_handle_t pool;

  // We store a pointer to the actual allocation because it is needed when
  // freeing memory.
  void *base_alloc_ptr;
  constexpr usm_alloc_info(ur_usm_type_t type, const void *base_ptr,
                           size_t size, ur_device_handle_t device,
                           ur_usm_pool_handle_t pool, void *base_alloc_ptr)
      : type(type), base_ptr(base_ptr), size(size), device(device), pool(pool),
        base_alloc_ptr(base_alloc_ptr) {}
};
|
| 38 | + |
| 39 | +constexpr usm_alloc_info usm_alloc_info_null_entry(UR_USM_TYPE_UNKNOWN, nullptr, |
| 40 | + 0, nullptr, nullptr, |
| 41 | + nullptr); |
| 42 | + |
| 43 | +constexpr size_t alloc_header_size = sizeof(usm_alloc_info); |
| 44 | + |
| 45 | +// Computes the padding that we need to add to ensure the |
| 46 | +// pointer returned by UR is aligned as the user requested. |
| 47 | +static size_t get_padding(uint32_t alignment) { |
| 48 | + assert(alignment >= alignof(usm_alloc_info) && |
| 49 | + "memory not aligned to usm_alloc_info"); |
| 50 | + if (!alignment || alloc_header_size % alignment == 0) |
| 51 | + return 0; |
| 52 | + size_t padd = 0; |
| 53 | + if (alignment <= alloc_header_size) { |
| 54 | + padd = alignment - (alloc_header_size % alignment); |
| 55 | + } else { |
| 56 | + padd = alignment - alloc_header_size; |
| 57 | + } |
| 58 | + return padd; |
| 59 | +} |
| 60 | + |
| 61 | +// In order to satisfy the MemAllocInfo queries we allocate extra memory |
| 62 | +// for the native_cpu::usm_alloc_info struct. |
| 63 | +// To satisfy the alignment requirements we "pad" the memory |
| 64 | +// allocation so that the pointer returned to the user |
| 65 | +// always satisfies (ptr % align) == 0. |
| 66 | +static inline void *malloc_impl(uint32_t alignment, size_t size) { |
| 67 | + void *ptr = nullptr; |
| 68 | + assert(alignment >= alignof(usm_alloc_info) && |
| 69 | + "memory not aligned to usm_alloc_info"); |
| 70 | +#ifdef _MSC_VER |
| 71 | + ptr = _aligned_malloc(alloc_header_size + get_padding(alignment) + size, |
| 72 | + alignment); |
| 73 | + |
| 74 | +#else |
| 75 | + ptr = std::aligned_alloc(alignment, |
| 76 | + alloc_header_size + get_padding(alignment) + size); |
| 77 | +#endif |
| 78 | + return ptr; |
| 79 | +} |
| 80 | + |
| 81 | +// The info struct is retrieved by subtracting its size from the pointer |
| 82 | +// returned to the user. |
| 83 | +static inline uint8_t *get_alloc_info_addr(const void *ptr) { |
| 84 | + return (uint8_t *)const_cast<void *>(ptr) - alloc_header_size; |
| 85 | +} |
| 86 | + |
| 87 | +static usm_alloc_info get_alloc_info(void *ptr) { |
| 88 | + return *(usm_alloc_info *)get_alloc_info_addr(ptr); |
| 89 | +} |
| 90 | + |
32 | 91 | } // namespace native_cpu
|
33 | 92 |
|
34 | 93 | struct ur_context_handle_t_ : RefCounted {
|
35 | 94 | ur_context_handle_t_(ur_device_handle_t_ *phDevices) : _device{phDevices} {}
|
36 | 95 |
|
37 | 96 | ur_device_handle_t _device;
|
38 | 97 |
|
39 |
| - void add_alloc_info_entry(const void *ptr, ur_usm_type_t type, size_t size, |
40 |
| - ur_usm_pool_handle_t pool) { |
41 |
| - native_cpu::usm_alloc_info info(type, ptr, size, this->_device, pool); |
42 |
| - alloc_info.insert(std::make_pair(ptr, info)); |
| 98 | + ur_result_t remove_alloc(void *ptr) { |
| 99 | + std::lock_guard<std::mutex> lock(alloc_mutex); |
| 100 | + const native_cpu::usm_alloc_info &info = native_cpu::get_alloc_info(ptr); |
| 101 | + UR_ASSERT(info.type != UR_USM_TYPE_UNKNOWN, |
| 102 | + UR_RESULT_ERROR_INVALID_MEM_OBJECT); |
| 103 | +#ifdef _MSC_VER |
| 104 | + _aligned_free(info.base_alloc_ptr); |
| 105 | +#else |
| 106 | + free(info.base_alloc_ptr); |
| 107 | +#endif |
| 108 | + allocations.erase(ptr); |
| 109 | + return UR_RESULT_SUCCESS; |
43 | 110 | }
|
44 | 111 |
|
45 |
| - native_cpu::usm_alloc_info get_alloc_info_entry(const void *ptr) const { |
46 |
| - auto it = alloc_info.find(ptr); |
47 |
| - if (it == alloc_info.end()) { |
48 |
| - return native_cpu::usm_alloc_info(UR_USM_TYPE_UNKNOWN, ptr, 0, nullptr, |
49 |
| - nullptr); |
| 112 | + const native_cpu::usm_alloc_info & |
| 113 | + get_alloc_info_entry(const void *ptr) const { |
| 114 | + auto it = allocations.find(ptr); |
| 115 | + if (it == allocations.end()) { |
| 116 | + return native_cpu::usm_alloc_info_null_entry; |
50 | 117 | }
|
51 |
| - return it->second; |
| 118 | + |
| 119 | + return *(native_cpu::usm_alloc_info *)native_cpu::get_alloc_info_addr(ptr); |
52 | 120 | }
|
53 | 121 |
|
54 |
| - void remove_alloc_info_entry(void *ptr) { alloc_info.erase(ptr); } |
| 122 | + void *add_alloc(uint32_t alignment, ur_usm_type_t type, size_t size, |
| 123 | + ur_usm_pool_handle_t pool) { |
| 124 | + std::lock_guard<std::mutex> lock(alloc_mutex); |
| 125 | + // We need to ensure that we align to at least alignof(usm_alloc_info), |
| 126 | + // otherwise its start address may be unaligned. |
| 127 | + alignment = |
| 128 | + std::max<size_t>(alignment, alignof(native_cpu::usm_alloc_info)); |
| 129 | + void *alloc = native_cpu::malloc_impl(alignment, size); |
| 130 | + if (!alloc) |
| 131 | + return nullptr; |
| 132 | + // Compute the address of the pointer that we'll return to the user. |
| 133 | + void *ptr = native_cpu::alloc_header_size + |
| 134 | + native_cpu::get_padding(alignment) + (uint8_t *)alloc; |
| 135 | + uint8_t *info_addr = native_cpu::get_alloc_info_addr(ptr); |
| 136 | + if (!info_addr) |
| 137 | + return nullptr; |
| 138 | + // Do a placement new of the alloc_info to avoid allocation and copy |
| 139 | + auto info = new (info_addr) |
| 140 | + native_cpu::usm_alloc_info(type, ptr, size, this->_device, pool, alloc); |
| 141 | + if (!info) |
| 142 | + return nullptr; |
| 143 | + allocations.insert(ptr); |
| 144 | + return ptr; |
| 145 | + } |
55 | 146 |
|
56 | 147 | private:
|
57 |
| - std::unordered_map<const void *, native_cpu::usm_alloc_info> alloc_info; |
| 148 | + std::mutex alloc_mutex; |
| 149 | + std::set<const void *> allocations; |
58 | 150 | };
|