@@ -117,6 +117,67 @@ void embedding_inplace_update_cpu(
117
117
});
118
118
}
119
119
120
+ void dram_kv_embedding_inplace_update_cpu (
121
+ torch::jit::Module* tbe_module,
122
+ std::string tbe_module_update_func_name,
123
+ Tensor weights_placements,
124
+ Tensor weights_offsets,
125
+ Tensor weights_tys,
126
+ Tensor D_offsets,
127
+ Tensor update_weights,
128
+ Tensor update_table_idx,
129
+ Tensor update_row_idx,
130
+ Tensor update_offsets,
131
+ const int64_t row_alignment) {
132
+ TENSOR_ON_CPU (weights_placements);
133
+ TENSOR_ON_CPU (weights_offsets);
134
+ TENSOR_ON_CPU (weights_tys);
135
+ TENSOR_ON_CPU (D_offsets);
136
+
137
+ TENSOR_ON_CPU (update_table_idx);
138
+ TENSOR_ON_CPU (update_row_idx);
139
+ TENSOR_ON_CPU (update_offsets);
140
+ TENSOR_ON_CPU (update_weights);
141
+
142
+ int64_t N = update_row_idx.numel ();
143
+ if (N == 0 ) {
144
+ return ;
145
+ }
146
+ auto embedding_inplace_update_method =
147
+ tbe_module->find_method (tbe_module_update_func_name);
148
+ TORCH_CHECK (embedding_inplace_update_method.has_value ());
149
+
150
+ const uint8_t * weights_tys_ptr = weights_tys.data_ptr <uint8_t >();
151
+ const int32_t * D_offsets_ptr = D_offsets.data_ptr <int32_t >();
152
+ const uint8_t * update_weights_ptr = update_weights.data_ptr <uint8_t >();
153
+ const int32_t * update_table_idx_ptr = update_table_idx.data_ptr <int32_t >();
154
+ const int64_t * update_row_idx_ptr = update_row_idx.data_ptr <int64_t >();
155
+ const int64_t * update_offsets_ptr = update_offsets.data_ptr <int64_t >();
156
+
157
+ for (int64_t n = 0 ; n < N; ++n) {
158
+ int32_t t = update_table_idx_ptr[n];
159
+ int64_t row_idx = update_row_idx_ptr[n];
160
+ SparseType weight_ty = static_cast <SparseType>(weights_tys_ptr[t]);
161
+ int32_t D_start = D_offsets_ptr[t];
162
+ int32_t D_end = D_offsets_ptr[t + 1 ];
163
+ int32_t D = D_end - D_start;
164
+ int32_t D_bytes =
165
+ nbit::padded_row_size_in_bytes (D, weight_ty, row_alignment);
166
+
167
+ int64_t update_weight_offset = update_offsets_ptr[n];
168
+ const uint8_t * update_weight_row =
169
+ update_weights_ptr + update_weight_offset;
170
+ std::vector<uint8_t > tmp (update_weight_row, update_weight_row + D_bytes);
171
+ at::Tensor update_weight =
172
+ at::from_blob (
173
+ tmp.data (), {1 , D_bytes}, at::TensorOptions ().dtype (at::kByte ))
174
+ .clone ();
175
+ at::Tensor row_id =
176
+ at::full ({1 }, row_idx, at::TensorOptions ().dtype (at::kLong ));
177
+ (*embedding_inplace_update_method)({t, row_id, update_weight});
178
+ }
179
+ }
180
+
120
181
Tensor pruned_array_lookup_from_row_idx_cpu (
121
182
const Tensor& update_row_indices,
122
183
const Tensor& update_table_indices,
0 commit comments