Skip to content

Commit f2da3ab

Browse files
chenyuz and facebook-github-bot
authored and committed
kv embedding inference cache wrapper (#4343)
Summary: X-link: facebookresearch/FBGEMM#1411 Pull Request resolved: #4343 **Context** This diff adds a kv embedding inference cache wrapper that can be used in python operator. The kv implementation is reusing the training backend [DramKVEmbeddingCache](https://fburl.com/code/l8rivkd3) Design doc for kv embedding cache for inference https://docs.google.com/document/d/1TNJMnj-PPKWitMgwB8HJIsFT3OwotiEnqgh60fI9P48/edit?tab=t.0#heading=h.o9irumwgl8gj Differential Revision: D72587941
1 parent 7dbd73a commit f2da3ab

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.h"
10+
#include <torch/custom_class.h>
11+
#include "deeplearning/fbgemm/fbgemm_gpu/include/fbgemm_gpu/embedding_common.h" // @manual=//deeplearning/fbgemm/fbgemm_gpu:fbgemm_gpu
12+
13+
namespace fbgemm_gpu {
14+
15+
/// Records the cache configuration parameters.
/// The DRAM cache itself is not allocated here; callers must invoke init()
/// before set_embeddings()/get_embeddings().
DramKVEmbeddingInferenceWrapper::DramKVEmbeddingInferenceWrapper(
    int64_t num_shards,
    double uniform_init_lower,
    double uniform_init_upper,
    int64_t evict_trigger_mode) {
  // Members carry in-class defaults; overwrite them with the caller's values.
  num_shards_ = num_shards;
  uniform_init_lower_ = uniform_init_lower;
  uniform_init_upper_ = uniform_init_upper;
  evict_trigger_mode_ = evict_trigger_mode;
}
24+
25+
void DramKVEmbeddingInferenceWrapper::init(
26+
const std::vector<SerializedSepcType>& specs,
27+
const int64_t row_alignment,
28+
const int64_t scale_bias_size_in_bytes) {
29+
int64_t max_D = 0;
30+
for (auto i = 0; i < specs.size(); ++i) {
31+
max_D = std::max(max_D, std::get<1>(specs[i]));
32+
}
33+
max_row_bytes_ = nbit::padded_row_size_in_bytes(
34+
static_cast<int32_t>(max_D),
35+
static_cast<fbgemm_gpu::SparseType>(std::get<2>(specs[0])),
36+
static_cast<int32_t>(row_alignment),
37+
static_cast<int32_t>(scale_bias_size_in_bytes));
38+
dram_cache_ = std::make_unique<kv_mem::DramKVEmbeddingCache<uint8_t>>(
39+
max_row_bytes_,
40+
uniform_init_lower_,
41+
uniform_init_upper_,
42+
evict_trigger_mode_,
43+
0 /* trigger_step_intervals */,
44+
0 /* mem_util_threshold_in_GB */,
45+
1 /* evict_trigger_strategy */,
46+
std::nullopt /* counter_thresholds */,
47+
std::nullopt /* ttls_in_mins */,
48+
std::nullopt /* counter_decay_rates */,
49+
std::nullopt /* l2_weight_thresholds */,
50+
num_shards_ /* num_shards */,
51+
num_shards_ /* num_threads */,
52+
8 /* row_storage_bitwidth */);
53+
return;
54+
}
55+
56+
void DramKVEmbeddingInferenceWrapper::set_embeddings(
57+
const at::Tensor& indices,
58+
const at::Tensor& weights) {
59+
const auto count = at::tensor({indices.numel()}, at::ScalarType::Long);
60+
folly::coro::blockingWait(
61+
dram_cache_->set_kv_db_async(indices, weights, count));
62+
return;
63+
}
64+
65+
/// Fetches the cached rows for `indices`.
/// @return a byte tensor of shape [indices.numel(), max_row_bytes_].
at::Tensor DramKVEmbeddingInferenceWrapper::get_embeddings(
    const at::Tensor& indices) {
  const int64_t num_rows = indices.numel();
  // One padded row of max_row_bytes_ bytes per requested index.
  auto weights = at::empty({num_rows, max_row_bytes_}, at::kByte);
  const auto count = at::tensor({num_rows}, at::ScalarType::Long);
  folly::coro::blockingWait(
      dram_cache_->get_kv_db_async(indices, weights, count));
  return weights;
}
78+
79+
/// Packs the wrapper's configuration into tensors for pickling.
/// Layout: [0] int64 {num_shards, evict_trigger_mode};
///         [1] double {uniform_init_lower, uniform_init_upper}.
/// The cache contents themselves are NOT serialized.
c10::List<at::Tensor> DramKVEmbeddingInferenceWrapper::serialize() const {
  const auto int_state =
      torch::tensor({num_shards_, evict_trigger_mode_}, torch::kInt64);
  const auto float_state = torch::tensor(
      {uniform_init_lower_, uniform_init_upper_}, torch::kDouble);
  c10::List<at::Tensor> state;
  state.push_back(int_state);
  state.push_back(float_state);
  return state;
}
87+
88+
void DramKVEmbeddingInferenceWrapper::deserialize(
89+
const c10::List<at::Tensor>& states) {
90+
if (states.empty()) {
91+
return;
92+
}
93+
TORCH_CHECK(states.size() >= 2);
94+
95+
auto* intPtr = states[0].data_ptr<int64_t>();
96+
TORCH_CHECK(states[0].numel() >= 2)
97+
num_shards_ = intPtr[0];
98+
evict_trigger_mode_ = intPtr[1];
99+
100+
TORCH_CHECK(states[1].numel() >= 2)
101+
auto* floatPtr = states[1].data_ptr<double>();
102+
uniform_init_lower_ = floatPtr[0];
103+
uniform_init_upper_ = floatPtr[1];
104+
}
105+
106+
} // namespace fbgemm_gpu
107+
108+
// Registers DramKVEmbeddingInferenceWrapper as a TorchScript custom class
// under the qualified name "fbgemm::DramKVEmbeddingInferenceWrapper", so it
// can be constructed and driven from Python/TorchScript operators.
static auto dram_kv_embedding_inference_wrapper =
    torch::class_<fbgemm_gpu::DramKVEmbeddingInferenceWrapper>(
        "fbgemm",
        "DramKVEmbeddingInferenceWrapper")
        // Binds the (num_shards, uniform_init_lower, uniform_init_upper,
        // evict_trigger_mode) constructor.
        .def(torch::init<int64_t, double, double, int64_t>())
        .def("init", &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::init)
        .def(
            "set_embeddings",
            &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::set_embeddings)
        .def(
            "get_embeddings",
            &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::get_embeddings)
        .def(
            "serialize",
            &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::serialize)
        .def(
            "deserialize",
            &fbgemm_gpu::DramKVEmbeddingInferenceWrapper::deserialize)
        // Pickle support: __getstate__ returns serialize()'s tensor list;
        // __setstate__ rebuilds a default instance then applies deserialize().
        // NOTE(review): only configuration round-trips — cache contents are
        // not pickled, so init() must be re-run after unpickling.
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<
                fbgemm_gpu::DramKVEmbeddingInferenceWrapper>& self)
                -> c10::List<at::Tensor> { return self->serialize(); },
            // __setstate__
            [](const c10::List<at::Tensor>& states) {
              auto ptr = c10::make_intrusive<
                  fbgemm_gpu::DramKVEmbeddingInferenceWrapper>(
                  fbgemm_gpu::DramKVEmbeddingInferenceWrapper());
              ptr->deserialize(states);
              return ptr;
            });
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <torch/custom_class.h>
12+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h"
13+
14+
namespace fbgemm_gpu {
15+
16+
/// TorchScript custom-class wrapper exposing a DRAM-backed KV embedding
/// cache (kv_mem::DramKVEmbeddingCache) for inference. The constructor only
/// records parameters; init() allocates the cache and must be called before
/// set_embeddings()/get_embeddings().
class DramKVEmbeddingInferenceWrapper : public torch::jit::CustomClassHolder {
 public:
  DramKVEmbeddingInferenceWrapper(
      int64_t num_shards = 32,
      double uniform_init_lower = 0.0,
      double uniform_init_upper = 0.0,
      int64_t evict_trigger_mode = 0);

  // Serialized per-table spec tuple.
  // NOTE(review): "Sepc" is a typo for "Spec", kept to avoid breaking
  // existing callers of this public alias.
  using SerializedSepcType =
      std::tuple<int64_t, int64_t, int64_t>; // (rows, dim, sparse_type)

  /// Allocates the DRAM cache, sizing rows to the widest table in `specs`.
  void init(
      const std::vector<SerializedSepcType>& specs,
      const int64_t row_alignment,
      const int64_t scale_bias_size_in_bytes);

  /// Blocking write of `weights` rows at `indices`.
  void set_embeddings(const at::Tensor& indices, const at::Tensor& weights);

  /// Blocking read; returns a [indices.numel(), max_row_bytes_] byte tensor.
  at::Tensor get_embeddings(const at::Tensor& indices);

  /// Packs configuration (not cache contents) into tensors for pickling.
  c10::List<at::Tensor> serialize() const;

  /// Restores configuration from serialize()'s output; empty list is a no-op.
  void deserialize(const c10::List<at::Tensor>& states);

 private:
  int64_t num_shards_ = 32;          // shard count (also used as thread count)
  double uniform_init_lower_ = 0.0;  // lower bound for uniform row init
  double uniform_init_upper_ = 0.0;  // upper bound for uniform row init
  int64_t evict_trigger_mode_ = 0;   // eviction trigger mode forwarded to cache

  // Created by init(); null until then.
  std::unique_ptr<kv_mem::DramKVEmbeddingCache<uint8_t>> dram_cache_;
  // Padded byte width of one cached row; set by init().
  int64_t max_row_bytes_ = 0;
};
49+
50+
} // namespace fbgemm_gpu

0 commit comments

Comments
 (0)