Commit 13e11ac

Jianbo Liu authored and facebook-github-bot committed
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (#4429)
Summary:
X-link: facebookresearch/FBGEMM#1495
X-link: pytorch/torchrec#3148

# Context

In the current KVZCH checkpoint loading flow, we hold the weight_id, weight, and optimizer tensors for the entire checkpoint loading lifecycle; only once all of them have been downloaded do we explicitly call "apply_state_dict" to write them chunk by chunk to the backend, which keeps id->weight and id->opt correctly mapped. The problem is that with a large number of weights we run short of memory, since all three tensors must be held at once (a double-memory issue). To solve this, we save the whole row (metaheader + weight + optimizer) as a single "weight" tensor during checkpoint saving; when loading the checkpoint, we can extract the id from the metaheader and write the weight+optimizer part directly to the backend by id. For the optimizer load path we add a no-op KVTensor, so optimizer states do not need to be written to the backend a second time.

# This diff only contains the backend change

* Updated the DRAM backend and memory pool so they can return metaheader + weight + optimizer_state together, and accept them back (pointer offsets are used to skip the metaheader part when writing weight+opt to the backend).

Differential Revision: D77604158
1 parent c70539a commit 13e11ac
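For orientation, here is a minimal sketch of the whole-row format described above, under an assumed, simplified header layout. SketchMetaHeader is an illustrative stand-in for FixedBlockPool::MetaHeader (defined in fixed_block_pool.h), not the actual struct from this commit.

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for FixedBlockPool::MetaHeader; the real layout may
// carry more fields. Shown only to illustrate the whole-row format.
struct SketchMetaHeader {
  int64_t id;   // global row id, recovered from the header on checkpoint load
  bool used;    // rows whose header is marked unused are skipped on write-back
};

// A whole row, as serialized into the checkpoint "weight" tensor:
//   [ metaheader | emb_dim weight elements | opt_dim optimizer elements ]
// The id travels with the row, so no separate weight_id tensor must be held.
constexpr std::size_t kSketchHeaderBytes = sizeof(SketchMetaHeader);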

File tree

7 files changed: +322 −22 lines changed


fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 252 additions & 10 deletions
@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false,
       std::optional<at::Tensor> table_dims = std::nullopt,
       std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             block_alignment_,
             /*blocks_per_chunk=*/8192)),
         elem_size_(row_storage_bitwidth / 8),
+        backend_return_whole_row_(backend_return_whole_row),
         feature_evict_config_(feature_evict_config) {
     executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
         num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -608,11 +610,15 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   void set_range_to_storage(
       const at::Tensor& weights,
       const int64_t start,
-      const int64_t length) {
-    const auto seq_indices =
-        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+      const int64_t length) override {
+    if (backend_return_whole_row_) {
+      set_kv_with_metaheader_to_storage(weights);
+    } else {
+      const auto seq_indices = at::arange(
+          start, start + length, at::TensorOptions().dtype(at::kLong));
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+    }
   }

   void get_range_from_snapshot(
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     CHECK(snapshot_handle == nullptr);
     const auto seq_indices =
         at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    get_kv_db_async_impl(
-        seq_indices, weights, count, width_offset, width_length)
-        .wait();
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(seq_indices, weights);
+    } else {
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      get_kv_db_async_impl(
+          seq_indices, weights, count, width_offset, width_length)
+          .wait();
+    }
+
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
     // state_dict() calls
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t width_offset = 0,
       std::optional<int64_t> width_length = std::nullopt) override {
     CHECK(snapshot_handle == nullptr);
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(
+          ids, weights, width_offset, width_length);
+    } else {
+      const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+      get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+          .wait();
+    }
+  }
+
+  // used for ckpt, get kv with metaheader from storage
+  void get_kv_with_metaheader_from_storage(
+      const at::Tensor& ids,
+      const at::Tensor& weights_with_metaheader,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
     const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
-    get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+    get_kv_db_with_metaheader_async_impl(
+        ids, weights_with_metaheader, count, width_offset, width_length)
+        .wait();
+  }
+
+  void set_kv_with_metaheader_to_storage(
+      const at::Tensor& weights_with_metaheader) {
+    std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
+    for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
+      keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
+    }
+    auto indices =
+        torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
+    const auto count =
+        at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
+    set_kv_db_with_metaheader_async_impl(
+        indices, weights_with_metaheader, count)
         .wait();
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
@@ -826,6 +871,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {

   void flush_or_compact(const int64_t timestep) override {}

+  bool get_backend_return_whole_row() override {
+    return backend_return_whole_row_;
+  }
+
+  int64_t get_metaheader_width_in_front() override {
+    return backend_return_whole_row_
+        ? FixedBlockPool::get_metaheader_dim<weight_type>()
+        : 0;
+  }
+
   void resume_ongoing_eviction() override {
     if (feature_evict_) {
       feature_evict_->resume();
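A sketch of how a caller might use these two getters to pick read offsets once the backend reports whole-row mode; db and emb_dim are assumed names for illustration and are not part of this diff.

// Sketch only: choosing width_offset values in whole-row mode.
if (db->get_backend_return_whole_row()) {
  const int64_t meta_dim = db->get_metaheader_width_in_front();
  const int64_t weight_offset = 0;                // weight KVT reads from the front of the row
  const int64_t opt_offset = meta_dim + emb_dim;  // optimizer KVT reads only its own slice
  // ... pass these as width_offset to the get-from-storage calls shown above
}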
@@ -930,6 +985,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     return ret;
   }

+  /// Get embeddings and metaheader from kvstore.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row(embeddings) is
+  /// paired up with relative element in <indices>. This tensor will be
+  /// filled up with the returned embeddings from KVstore.
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  get_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto row_width = weights_with_metaheader.size(1);
+    auto copy_width = width_length.value_or(row_width);
+    CHECK_LE(row_width, block_size_);
+    CHECK_EQ(copy_width, row_width);
+    auto shardid_to_indexes = shard_input(indices, count);
+
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue([this,
+                          shard_id,
+                          indexes,
+                          &indices,
+                          &weights_with_metaheader,
+                          width_offset,
+                          row_width](folly::Unit) {
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    indices.scalar_type(),
+                    "dram_kvstore_get_with_metaheader",
+                    [this,
+                     shard_id,
+                     indexes,
+                     &indices,
+                     &weights_with_metaheader,
+                     width_offset,
+                     row_width] {
+                      using index_t = scalar_t;
+                      CHECK(indices.is_contiguous());
+                      CHECK(weights_with_metaheader.is_contiguous());
+                      CHECK_EQ(
+                          indices.size(0), weights_with_metaheader.size(0));
+                      auto wlmap = kv_store_.by(shard_id).wlock();
+                      auto indices_data_ptr = indices.data_ptr<index_t>();
+                      auto weights_data_ptr =
+                          weights_with_metaheader.data_ptr<weight_type>();
+                      {
+                        for (auto index_iter = indexes.begin();
+                             index_iter != indexes.end();
+                             index_iter++) {
+                          const auto weights_row_index = *index_iter;
+                          auto weight_idx =
+                              int64_t(indices_data_ptr[weights_row_index]);
+                          const auto cached_iter = wlmap->find(weight_idx);
+                          // Defensive programming
+                          // it shouldn't occur under normal circumstances
+                          if (cached_iter == wlmap->end()) {
+                            std::memset(
+                                &(weights_data_ptr
+                                      [weights_row_index * row_width]),
+                                0,
+                                row_width);
+                            continue;
+                          }
+
+                          // For weight KVT, offset=0 and it will read the whole
+                          // row. For optimizer, offset=dim(metaheader) +
+                          // emb_dim so it will only read the optimizer part
+                          const auto* ptr_offset_from_front =
+                              FixedBlockPool::ptr_offset_from_front<
+                                  weight_type>(
+                                  cached_iter->second, width_offset);
+                          std::copy(
+                              ptr_offset_from_front,
+                              ptr_offset_from_front + row_width,
+                              &(weights_data_ptr
+                                    [weights_row_index * row_width]));
+                        }
+                      }
+                    });
+              });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
+  /// insert embeddings and metaheader into kvstore.
+  /// current underlying memory management is done through F14FastMap
+  /// key value pair will be sharded into multiple shards to increase
+  /// parallelism.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row(embeddings with
+  /// metaheader) is paired up with relative element in <indices>
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  set_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto shardid_to_indexes = shard_input(indices, count);
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue(
+                  [this, shard_id, indexes, &indices, &weights_with_metaheader](
+                      folly::Unit) {
+                    FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                        indices.scalar_type(),
+                        "dram_kv_set_with_metaheader",
+                        [this,
+                         shard_id,
+                         indexes,
+                         &indices,
+                         &weights_with_metaheader] {
+                          using index_t = scalar_t;
+                          CHECK(indices.is_contiguous());
+                          CHECK(weights_with_metaheader.is_contiguous());
+                          CHECK_EQ(
+                              indices.size(0), weights_with_metaheader.size(0));
+                          {
+                            auto wlmap = kv_store_.by(shard_id).wlock();
+                            auto* pool = kv_store_.pool_by(shard_id);
+                            int64_t stride = weights_with_metaheader.size(1);
+                            auto indices_data_ptr = indices.data_ptr<index_t>();
+                            auto weights_data_ptr =
+                                weights_with_metaheader.data_ptr<weight_type>();
+                            for (auto index_iter = indexes.begin();
+                                 index_iter != indexes.end();
+                                 index_iter++) {
+                              const auto& id_index = *index_iter;
+                              auto id = int64_t(indices_data_ptr[id_index]);
+                              // Defensive programming
+                              // it shouldn't occur under normal circumstances
+                              auto used = FixedBlockPool::get_used(
+                                  weights_data_ptr + id_index * stride);
+                              if (!used) {
+                                continue;
+                              }
+                              // use mempool
+                              weight_type* block = nullptr;
+                              // First check if the key already exists
+                              auto it = wlmap->find(id);
+                              if (it != wlmap->end()) {
+                                block = it->second;
+                              } else {
+                                // Key doesn't exist, allocate new block and
+                                // insert.
+                                block =
+                                    pool->template allocate_t<weight_type>();
+                                wlmap->insert({id, block});
+                              }
+                              std::copy(
+                                  weights_data_ptr + id_index * stride,
+                                  weights_data_ptr + (id_index + 1) * stride,
+                                  block);
+                            }
+                          }
+                        });
+                  });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
   std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
   // background thread
   folly::FunctionScheduler scheduler_;
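A rough sketch of the checkpoint round trip these helpers enable, under assumed dimensions; cache, emb_dim, opt_dim, and num_rows are illustrative names, and only set_range_to_storage's signature comes from this diff.

// Sketch only: whole-row checkpoint rows at the tensor level.
const int64_t meta_dim = cache.get_metaheader_width_in_front();
const int64_t row_width = meta_dim + emb_dim + opt_dim;  // header + weight + optimizer
// Save path: each row returned by the backend already carries its metaheader,
// weights, and optimizer state at this width.
auto rows = at::empty({num_rows, row_width}, at::kHalf);
// Load path: the same rows go straight back; the backend recovers each id from
// the row header (FixedBlockPool::get_key) and skips rows marked unused, so no
// separate id tensor has to be held in memory.
cache.set_range_to_storage(rows, /*start=*/0, /*length=*/num_rows);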
@@ -942,6 +1183,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   std::atomic_bool is_eviction_ongoing_ = false;
   std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
   int64_t elem_size_;
+  bool backend_return_whole_row_;
   std::vector<int64_t> sub_table_dims_;
   std::vector<int64_t> sub_table_hash_cumsum_;
   std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       int64_t row_storage_bitwidth = 32,
       const std::optional<at::Tensor>& table_dims = std::nullopt,
      const std::optional<at::Tensor>& hash_size_cumsum = std::nullopt,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false) {
     if (row_storage_bitwidth == 16) {
       impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
@@ -47,6 +48,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           num_shards,
           num_threads,
           row_storage_bitwidth,
+          backend_return_whole_row,
           enable_async_update,
           table_dims,
           hash_size_cumsum);
@@ -59,6 +61,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           num_shards,
           num_threads,
           row_storage_bitwidth,
+          backend_return_whole_row,
           enable_async_update,
           table_dims,
           hash_size_cumsum);

fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h

Lines changed: 22 additions & 0 deletions
@@ -100,6 +100,12 @@ class FixedBlockPool : public std::pmr::memory_resource {
     return std::max(alignof(FixedBlockPool::MetaHeader), alignof(scalar_t));
   }

+  // Get dimension of Metaheader
+  template <typename scalar_t>
+  static size_t get_metaheader_dim() {
+    return sizeof(FixedBlockPool::MetaHeader) / sizeof(scalar_t);
+  }
+
   // Data pointer retrieval
   template <typename scalar_t>
   static scalar_t* data_ptr(scalar_t* block) {
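A worked example of this arithmetic, assuming a 16-byte MetaHeader purely for illustration (the real sizeof may differ).

// If sizeof(FixedBlockPool::MetaHeader) were 16 bytes, then:
//   get_metaheader_dim<at::Half>() == 16 / 2 == 8 leading elements per row
//   get_metaheader_dim<float>()    == 16 / 4 == 4 leading elements per row
// This is the value DramKVEmbeddingCache surfaces through
// get_metaheader_width_in_front() when whole-row mode is enabled.
static_assert(sizeof(at::Half) == 2, "fp16 element size assumed above");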
@@ -114,6 +120,22 @@ class FixedBlockPool : public std::pmr::memory_resource {
         sizeof(FixedBlockPool::MetaHeader));
   }

+  template <typename scalar_t>
+  static scalar_t* ptr_offset_from_front(
+      scalar_t* block,
+      const int64_t offset) {
+    return reinterpret_cast<scalar_t*>(
+        reinterpret_cast<char*>(block) + offset * sizeof(scalar_t));
+  }
+
+  template <typename scalar_t>
+  static const scalar_t* ptr_offset_from_front(
+      const scalar_t* block,
+      const int64_t offset) {
+    return reinterpret_cast<const scalar_t*>(
+        reinterpret_cast<const char*>(block) + offset * sizeof(scalar_t));
+  }
+
   template <typename scalar_t>
   static scalar_t get_l2weight(scalar_t* block, size_t dimension) {
     scalar_t* data = FixedBlockPool::data_ptr(block);
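A sketch of how these helpers slice a stored block, mirroring the offsets used in the DRAM cache above; block, emb_dim, and opt_dim are assumed names for illustration.

// Sketch only: reading slices of a whole-row block (offsets are in elements).
const auto meta_dim =
    static_cast<int64_t>(FixedBlockPool::get_metaheader_dim<at::Half>());
// Weight read: start at the front of the block (metaheader included).
const at::Half* row_begin =
    FixedBlockPool::ptr_offset_from_front<at::Half>(block, /*offset=*/0);
// Optimizer read: skip the header and the emb_dim weight elements, then copy
// opt_dim elements into the destination row.
const at::Half* opt_begin =
    FixedBlockPool::ptr_offset_from_front<at::Half>(block, meta_dim + emb_dim);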

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 10 additions & 0 deletions
@@ -402,6 +402,16 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
     return max_D_;
   }

+  virtual bool get_backend_return_whole_row() {
+    // only DRAM backend can enable this for now
+    return false;
+  }
+
+  virtual int64_t get_metaheader_width_in_front() {
+    // will return non-zero if DRAM enables backend_return_whole_row
+    return 0;
+  }
+
 #ifdef FBGEMM_FBCODE
   folly::coro::Task<void> tensor_stream(
       const at::Tensor& indices,
