@@ -76,11 +76,12 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
76
76
}
77
77
78
78
auto &Allocation = Allocations[Device];
79
+ ur_result_t URes = UR_RESULT_SUCCESS;
79
80
if (!Allocation) {
80
81
ur_usm_desc_t USMDesc{};
81
82
USMDesc.align = getAlignment ();
82
83
ur_usm_pool_handle_t Pool{};
83
- ur_result_t URes = getContext ()->interceptor ->allocateMemory (
84
+ URes = getContext ()->interceptor ->allocateMemory (
84
85
Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER,
85
86
ur_cast<void **>(&Allocation));
86
87
if (URes != UR_RESULT_SUCCESS) {
@@ -103,9 +104,57 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
103
104
}
104
105
}
105
106
107
+ // If the device required to allocate memory is not the previous one, we
108
+ // need to do data migration.
109
+ if (Device != LastSyncedDevice && LastSyncedDevice != nullptr ) {
110
+ auto &HostAllocation = Allocations[nullptr ];
111
+ if (!HostAllocation) {
112
+ ur_usm_desc_t USMDesc{};
113
+ USMDesc.align = getAlignment ();
114
+ ur_usm_pool_handle_t Pool{};
115
+ URes = getContext ()->interceptor ->allocateMemory (
116
+ Context, nullptr , &USMDesc, Pool, Size, AllocType::HOST_USM,
117
+ ur_cast<void **>(&HostAllocation));
118
+ if (URes != UR_RESULT_SUCCESS) {
119
+ getContext ()->logger .error (" Failed to allocate {} bytes host "
120
+ " USM for buffer {} migration" ,
121
+ Size, this );
122
+ return URes;
123
+ }
124
+ }
125
+
126
+ // Copy data from last synced device to host
127
+ {
128
+ ManagedQueue Queue (Context, LastSyncedDevice);
129
+ char *Handle;
130
+ UR_CALL (getHandle (LastSyncedDevice, Handle));
131
+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
132
+ Queue, true , HostAllocation, Handle, Size, 0 , nullptr , nullptr );
133
+ if (URes != UR_RESULT_SUCCESS) {
134
+ getContext ()->logger .error (
135
+ " Failed to migrate memory buffer data" );
136
+ return URes;
137
+ }
138
+ }
139
+
140
+ // Sync data back to device
141
+ {
142
+ ManagedQueue Queue (Context, Device);
143
+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
144
+ Queue, true , Allocation, HostAllocation, Size, 0 , nullptr ,
145
+ nullptr );
146
+ if (URes != UR_RESULT_SUCCESS) {
147
+ getContext ()->logger .error (
148
+ " Failed to migrate memory buffer data" );
149
+ return URes;
150
+ }
151
+ }
152
+ }
153
+
154
+ LastSyncedDevice = Device;
106
155
Handle = Allocation;
107
156
108
- return UR_RESULT_SUCCESS ;
157
+ return URes ;
109
158
}
110
159
111
160
ur_result_t MemBuffer::free () {
0 commit comments