9
9
*/
10
10
11
11
#include "memory_tracker.h"
12
+ #include "critnib/critnib.h"
13
+
14
+ #include <umf/memory_pool.h>
12
15
#include <umf/memory_provider.h>
13
16
#include <umf/memory_provider_ops.h>
14
17
15
- #include < cassert>
16
- #include < map>
17
- #include < mutex>
18
- #include < shared_mutex>
18
+ #include <assert.h>
19
+ #include <errno.h>
19
20
#include <stdlib.h>
20
21
21
- #ifdef _WIN32
22
- #include < windows.h>
23
- #endif
24
-
25
- // TODO: reimplement in C and optimize...
26
- struct umf_memory_tracker_t {
27
- enum umf_result_t add (void *pool, const void *ptr, size_t size) {
28
- std::unique_lock<std::shared_mutex> lock (mtx);
22
+ #if !defined(_WIN32 )
23
+ critnib * TRACKER = NULL ;
24
+ void __attribute__((constructor )) createLibTracker () {
25
+ TRACKER = critnib_new ();
26
+ }
27
+ void __attribute__((destructor )) deleteLibTracker () { critnib_delete (TRACKER ); }
29
28
30
- if (size == 0 ) {
31
- return UMF_RESULT_SUCCESS;
32
- }
29
+ umf_memory_tracker_handle_t umfMemoryTrackerGet (void ) {
30
+ return (umf_memory_tracker_handle_t )TRACKER ;
31
+ }
32
+ #endif
33
33
34
- auto ret =
35
- map. try_emplace ( reinterpret_cast < uintptr_t >(ptr), size, pool) ;
36
- return ret. second ? UMF_RESULT_SUCCESS : UMF_RESULT_ERROR_UNKNOWN ;
37
- }
34
+ struct tracker_value_t {
35
+ umf_memory_pool_handle_t pool ;
36
+ size_t size ;
37
+ };
38
38
39
- enum umf_result_t remove (const void *ptr, size_t size) {
40
- std::unique_lock<std::shared_mutex> lock (mtx);
39
+ static enum umf_result_t
40
+ umfMemoryTrackerAdd (umf_memory_tracker_handle_t hTracker ,
41
+ umf_memory_pool_handle_t pool , const void * ptr ,
42
+ size_t size ) {
43
+ assert (ptr );
41
44
42
- map.erase (reinterpret_cast <uintptr_t >(ptr));
45
+ struct tracker_value_t * value =
46
+ (struct tracker_value_t * )malloc (sizeof (struct tracker_value_t ));
47
+ value -> pool = pool ;
48
+ value -> size = size ;
43
49
44
- // TODO: handle removing part of the range
45
- (void )size;
50
+ int ret = critnib_insert ((critnib * )hTracker , (uintptr_t )ptr , value , 0 );
46
51
52
+ if (ret == 0 ) {
47
53
return UMF_RESULT_SUCCESS ;
48
54
}
49
55
50
- void *find (const void *ptr) {
51
- std::shared_lock<std::shared_mutex> lock (mtx);
52
-
53
- auto intptr = reinterpret_cast <uintptr_t >(ptr);
54
- auto it = map.upper_bound (intptr);
55
- if (it == map.begin ()) {
56
- return nullptr ;
57
- }
58
-
59
- --it;
60
-
61
- auto address = it->first ;
62
- auto size = it->second .first ;
63
- auto pool = it->second .second ;
64
-
65
- if (intptr >= address && intptr < address + size) {
66
- return pool;
67
- }
56
+ free (value );
68
57
69
- return nullptr ;
58
+ if (ret == ENOMEM ) {
59
+ return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY ;
70
60
}
71
61
72
- private:
73
- std::shared_mutex mtx;
74
- std::map<uintptr_t , std::pair<size_t , void *>> map;
75
- };
76
-
77
- static enum umf_result_t
78
- umfMemoryTrackerAdd (umf_memory_tracker_handle_t hTracker, void *pool,
79
- const void *ptr, size_t size) {
80
- return hTracker->add (pool, ptr, size);
62
+ // This should not happen
63
+ // TODO: add logging here
64
+ return UMF_RESULT_ERROR_UNKNOWN ;
81
65
}
82
66
83
67
static enum umf_result_t
84
68
umfMemoryTrackerRemove (umf_memory_tracker_handle_t hTracker , const void * ptr ,
85
69
size_t size ) {
86
- return hTracker->remove (ptr, size);
87
- }
70
+ assert (ptr );
71
+
72
+ // TODO: there is no support for removing partial ranges (or multipe entires
73
+ // in a single remove call) yet.
74
+ // Every umfMemoryTrackerAdd(..., ptr, ...) should have a corresponsding
75
+ // umfMemoryTrackerRemove call with the same ptr value.
76
+ (void )size ;
77
+
78
+ void * value = critnib_remove ((critnib * )hTracker , (uintptr_t )ptr );
79
+ if (!value ) {
80
+ // This should not happen
81
+ // TODO: add logging here
82
+ return UMF_RESULT_ERROR_UNKNOWN ;
83
+ }
88
84
89
- extern " C " {
85
+ free ( value );
90
86
91
- #if defined(_WIN32) && defined(UMF_SHARED_LIBRARY)
92
- umf_memory_tracker_t *tracker = nullptr ;
93
- BOOL APIENTRY DllMain (HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) {
94
- if (fdwReason == DLL_PROCESS_DETACH) {
95
- delete tracker;
96
- } else if (fdwReason == DLL_PROCESS_ATTACH) {
97
- tracker = new umf_memory_tracker_t ;
98
- }
99
- return TRUE ;
100
- }
101
- #elif defined(_WIN32)
102
- umf_memory_tracker_t trackerInstance;
103
- umf_memory_tracker_t *tracker = &trackerInstance;
104
- #else
105
- umf_memory_tracker_t *tracker = nullptr ;
106
- void __attribute__ ((constructor)) createLibTracker() {
107
- tracker = new umf_memory_tracker_t ;
87
+ return UMF_RESULT_SUCCESS ;
108
88
}
109
89
110
- void __attribute__ ((destructor)) deleteLibTracker() { delete tracker; }
111
- #endif
90
+ umf_memory_pool_handle_t
91
+ umfMemoryTrackerGetPool (umf_memory_tracker_handle_t hTracker , const void * ptr ) {
92
+ assert (ptr );
112
93
113
- umf_memory_tracker_handle_t umfMemoryTrackerGet (void ) { return tracker; }
94
+ uintptr_t rkey ;
95
+ struct tracker_value_t * rvalue ;
96
+ int found = critnib_find ((critnib * )hTracker , (uintptr_t )ptr , FIND_LE ,
97
+ (void * )& rkey , (void * * )& rvalue );
98
+ if (!found ) {
99
+ return NULL ;
100
+ }
114
101
115
- void *umfMemoryTrackerGetPool (umf_memory_tracker_handle_t hTracker,
116
- const void *ptr) {
117
- return hTracker->find (ptr);
102
+ return (rkey + rvalue -> size >= (uintptr_t )ptr ) ? rvalue -> pool : NULL ;
118
103
}
119
104
120
105
struct umf_tracking_memory_provider_t {
@@ -136,7 +121,7 @@ static enum umf_result_t trackingAlloc(void *hProvider, size_t size,
136
121
}
137
122
138
123
ret = umfMemoryProviderAlloc (p -> hUpstream , size , alignment , ptr );
139
- if (ret != UMF_RESULT_SUCCESS) {
124
+ if (ret != UMF_RESULT_SUCCESS || ! * ptr ) {
140
125
return ret ;
141
126
}
142
127
@@ -159,9 +144,11 @@ static enum umf_result_t trackingFree(void *hProvider, void *ptr, size_t size) {
159
144
// to avoid a race condition. If the order would be different, other thread
160
145
// could allocate the memory at address `ptr` before a call to umfMemoryTrackerRemove
161
146
// resulting in inconsistent state.
162
- ret = umfMemoryTrackerRemove (p->hTracker , ptr, size);
163
- if (ret != UMF_RESULT_SUCCESS) {
164
- return ret;
147
+ if (ptr ) {
148
+ ret = umfMemoryTrackerRemove (p -> hTracker , ptr , size );
149
+ if (ret != UMF_RESULT_SUCCESS ) {
150
+ return ret ;
151
+ }
165
152
}
166
153
167
154
ret = umfMemoryProviderFree (p -> hUpstream , ptr , size );
@@ -267,4 +254,3 @@ void umfTrackingMemoryProviderGetUpstreamProvider(
267
254
(umf_tracking_memory_provider_t * )hTrackingProvider ;
268
255
* hUpstream = p -> hUpstream ;
269
256
}
270
- }
0 commit comments