Skip to content

Commit b60e109

Browse files
EddyLXJ authored and facebook-github-bot committed
kv embedding dram delta loading in predictor (pytorch#4438)
Summary: Pull Request resolved: pytorch#4438 X-link: facebookresearch/FBGEMM#1502 support dram kv embedding delta loading. Reviewed By: emlin Differential Revision: D76356547 fbshipit-source-id: 82dcbec798f86d7d841c4c8c4291f734c8285a19
1 parent f2e75f5 commit b60e109

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

fbgemm_gpu/include/fbgemm_gpu/embedding_inplace_update.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ void embedding_inplace_update_cpu(
9090
std::nullopt // Not used, to match cache interface for CUDA op
9191
);
9292

93+
/// Applies a batch of delta (incremental) row updates to a DRAM KV embedding
/// backend by calling an update method on a TorchScript TBE module.
///
/// @param tbe_module TorchScript module that owns the embedding tables;
///        must expose the method named by `tbe_module_update_func_name`.
/// @param tbe_module_update_func_name Name of the module method invoked per
///        updated row (receives table index, row index tensor, row bytes).
/// @param weights_placements Per-table placement tensor (CPU).
/// @param weights_offsets Per-table weight offsets (CPU).
/// @param weights_tys Per-table SparseType codes, one uint8 per table (CPU).
/// @param D_offsets Per-table embedding-dim prefix offsets, int32 (CPU).
/// @param update_weights Packed row bytes for all updates, uint8 (CPU).
/// @param update_table_idx Table index per update, int32 (CPU).
/// @param update_row_idx Row index per update, int64 (CPU).
/// @param update_offsets Byte offset of each update row within
///        `update_weights`, int64 (CPU).
/// @param row_alignment Row byte alignment used to compute the padded
///        row size for each table.
void dram_kv_embedding_inplace_update_cpu(
    torch::jit::Module* tbe_module,
    std::string tbe_module_update_func_name,
    Tensor weights_placements,
    Tensor weights_offsets,
    Tensor weights_tys,
    Tensor D_offsets,
    Tensor update_weights,
    Tensor update_table_idx,
    Tensor update_row_idx,
    Tensor update_offsets,
    const int64_t row_alignment);
105+
93106
/**
94107
* Index remapping function that returns the remapped indices.
95108
*

fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,67 @@ void embedding_inplace_update_cpu(
117117
});
118118
}
119119

120+
/// Applies a batch of delta (incremental) row updates to a DRAM KV embedding
/// backend by invoking `tbe_module_update_func_name` on `tbe_module` once per
/// updated row.
///
/// For each update n, the row's byte width is derived from its table's
/// SparseType and embedding dimension (padded to `row_alignment`), the row
/// bytes are read from `update_weights` at `update_offsets[n]`, and the
/// module method is called with (table_idx, row_idx tensor, row bytes tensor).
///
/// All tensor arguments must be CPU tensors. `weights_placements` and
/// `weights_offsets` are validated but otherwise unused here — presumably
/// kept to mirror the non-KV inplace-update interface (TODO confirm).
void dram_kv_embedding_inplace_update_cpu(
    torch::jit::Module* tbe_module,
    std::string tbe_module_update_func_name,
    Tensor weights_placements,
    Tensor weights_offsets,
    Tensor weights_tys,
    Tensor D_offsets,
    Tensor update_weights,
    Tensor update_table_idx,
    Tensor update_row_idx,
    Tensor update_offsets,
    const int64_t row_alignment) {
  TENSOR_ON_CPU(weights_placements);
  TENSOR_ON_CPU(weights_offsets);
  TENSOR_ON_CPU(weights_tys);
  TENSOR_ON_CPU(D_offsets);

  TENSOR_ON_CPU(update_table_idx);
  TENSOR_ON_CPU(update_row_idx);
  TENSOR_ON_CPU(update_offsets);
  TENSOR_ON_CPU(update_weights);

  const int64_t N = update_row_idx.numel();
  if (N == 0) {
    return;
  }
  auto embedding_inplace_update_method =
      tbe_module->find_method(tbe_module_update_func_name);
  TORCH_CHECK(
      embedding_inplace_update_method.has_value(),
      "TBE module has no method named ",
      tbe_module_update_func_name);

  const uint8_t* weights_tys_ptr = weights_tys.data_ptr<uint8_t>();
  const int32_t* D_offsets_ptr = D_offsets.data_ptr<int32_t>();
  const uint8_t* update_weights_ptr = update_weights.data_ptr<uint8_t>();
  const int32_t* update_table_idx_ptr = update_table_idx.data_ptr<int32_t>();
  const int64_t* update_row_idx_ptr = update_row_idx.data_ptr<int64_t>();
  const int64_t* update_offsets_ptr = update_offsets.data_ptr<int64_t>();

  // Loop-invariant tensor options, hoisted out of the per-row loop.
  const auto byte_opts = at::TensorOptions().dtype(at::kByte);
  const auto long_opts = at::TensorOptions().dtype(at::kLong);

  for (int64_t n = 0; n < N; ++n) {
    const int32_t t = update_table_idx_ptr[n];
    const int64_t row_idx = update_row_idx_ptr[n];
    const SparseType weight_ty = static_cast<SparseType>(weights_tys_ptr[t]);
    const int32_t D_start = D_offsets_ptr[t];
    const int32_t D_end = D_offsets_ptr[t + 1];
    const int32_t D = D_end - D_start;
    // Padded byte width of one row for this table's element type.
    const int32_t D_bytes =
        nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

    const int64_t update_weight_offset = update_offsets_ptr[n];
    const uint8_t* update_weight_row =
        update_weights_ptr + update_weight_offset;
    // Wrap the source bytes in a non-owning view and clone once. The
    // const_cast is safe: from_blob does not write, and clone() only reads
    // the view while producing an owning copy. (Previously the bytes were
    // copied twice — into a std::vector and then again by clone().)
    at::Tensor update_weight =
        at::from_blob(
            const_cast<uint8_t*>(update_weight_row), {1, D_bytes}, byte_opts)
            .clone();
    at::Tensor row_id = at::full({1}, row_idx, long_opts);
    (*embedding_inplace_update_method)({t, row_id, update_weight});
  }
}
180+
120181
Tensor pruned_array_lookup_from_row_idx_cpu(
121182
const Tensor& update_row_indices,
122183
const Tensor& update_table_indices,

0 commit comments

Comments
 (0)