@@ -75,12 +75,14 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
75
75
return UR_RESULT_SUCCESS;
76
76
}
77
77
78
+ std::scoped_lock<ur_shared_mutex> Guard (Mutex);
78
79
auto &Allocation = Allocations[Device];
80
+ ur_result_t URes = UR_RESULT_SUCCESS;
79
81
if (!Allocation) {
80
82
ur_usm_desc_t USMDesc{};
81
83
USMDesc.align = getAlignment ();
82
84
ur_usm_pool_handle_t Pool{};
83
- ur_result_t URes = getContext ()->interceptor ->allocateMemory (
85
+ URes = getContext ()->interceptor ->allocateMemory (
84
86
Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER,
85
87
ur_cast<void **>(&Allocation));
86
88
if (URes != UR_RESULT_SUCCESS) {
@@ -105,7 +107,60 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
105
107
106
108
Handle = Allocation;
107
109
108
- return UR_RESULT_SUCCESS;
110
+ if (!LastSyncedDevice.hDevice ) {
111
+ LastSyncedDevice = MemBuffer::Device_t{Device, Handle};
112
+ return URes;
113
+ }
114
+
115
+ // If the device required to allocate memory is not the previous one, we
116
+ // need to do data migration.
117
+ if (Device != LastSyncedDevice.hDevice ) {
118
+ auto &HostAllocation = Allocations[nullptr ];
119
+ if (!HostAllocation) {
120
+ ur_usm_desc_t USMDesc{};
121
+ USMDesc.align = getAlignment ();
122
+ ur_usm_pool_handle_t Pool{};
123
+ URes = getContext ()->interceptor ->allocateMemory (
124
+ Context, nullptr , &USMDesc, Pool, Size, AllocType::HOST_USM,
125
+ ur_cast<void **>(&HostAllocation));
126
+ if (URes != UR_RESULT_SUCCESS) {
127
+ getContext ()->logger .error (" Failed to allocate {} bytes host "
128
+ " USM for buffer {} migration" ,
129
+ Size, this );
130
+ return URes;
131
+ }
132
+ }
133
+
134
+ // Copy data from last synced device to host
135
+ {
136
+ ManagedQueue Queue (Context, LastSyncedDevice.hDevice );
137
+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
138
+ Queue, true , HostAllocation, LastSyncedDevice.MemHandle , Size,
139
+ 0 , nullptr , nullptr );
140
+ if (URes != UR_RESULT_SUCCESS) {
141
+ getContext ()->logger .error (
142
+ " Failed to migrate memory buffer data" );
143
+ return URes;
144
+ }
145
+ }
146
+
147
+ // Sync data back to device
148
+ {
149
+ ManagedQueue Queue (Context, Device);
150
+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
151
+ Queue, true , Allocation, HostAllocation, Size, 0 , nullptr ,
152
+ nullptr );
153
+ if (URes != UR_RESULT_SUCCESS) {
154
+ getContext ()->logger .error (
155
+ " Failed to migrate memory buffer data" );
156
+ return URes;
157
+ }
158
+ }
159
+ }
160
+
161
+ LastSyncedDevice = MemBuffer::Device_t{Device, Handle};
162
+
163
+ return URes;
109
164
}
110
165
111
166
ur_result_t MemBuffer::free () {
0 commit comments