
Commit c3c4f7c

Jianbo Liu authored and facebook-github-bot committed
Support get/set of the whole row (metaheader + weight + optimizer) from the backend for checkpoint saving/loading (#4429)
Summary:
X-link: facebookresearch/FBGEMM#1495
X-link: pytorch/torchrec#3148

# Context

In the current KVZCH checkpoint-loading flow, we hold the weight_id, weight, and optimizer tensors for the entire checkpoint-loading lifecycle. Only once all of these tensors have been downloaded do we explicitly call "apply_state_dict" to write them chunk by chunk to the backend, which ensures that id->weight and id->opt are mapped correctly. The problem is that with a large number of weights we run short of memory, because all three tensors must be held at once (a double-memory issue).

To solve this, we save the whole row (metaheader + weight + optimizer state) as a single "weight" tensor during checkpoint saving. When downloading the checkpoint, we can then extract the id from the metaheader and write the weight+opt portion directly to the backend by id. When loading the checkpoint for the optimizer, a no-op KVTensor is added, so the optimizer states do not need to be written to the backend a second time.

# This diff only contains the backend change

* Updated the DRAM backend and memory pool so they can return metaheader + weight + optimizer_state together, and accept the same whole rows back (pointers are used to skip the metaheader when writing weight+opt to the backend).

Reviewed By: emlin

Differential Revision: D77604158
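For orientation, here is a minimal sketch of the row layout the whole-row mode assumes. The `MetaHeaderSketch` struct and the `row_width` helper are illustrative stand-ins, not the actual `FixedBlockPool::MetaHeader` definition (see fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h):

```cpp
#include <cstdint>

// Illustrative stand-in only; the real header is FixedBlockPool::MetaHeader
// and its fields may differ.
struct MetaHeaderSketch {
  int64_t key;  // embedding id, recovered from the row at checkpoint-load time
  bool used;    // whether the slot holds a live row
};

// Each checkpointed "weight" row is the whole backend row:
//   [ metaheader | embedding weight (emb_dim) | optimizer state (opt_dim) ]
// so the saved row width, in scalar elements, is:
constexpr int64_t row_width(
    int64_t metaheader_dim, int64_t emb_dim, int64_t opt_dim) {
  return metaheader_dim + emb_dim + opt_dim;
}
```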
1 parent 0dbd1bc commit c3c4f7c

File tree

8 files changed (+320, -25 lines)

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 245 additions & 9 deletions
@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false,
       std::optional<at::Tensor> table_dims = std::nullopt,
       std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
             block_alignment_,
             /*blocks_per_chunk=*/8192)),
         elem_size_(row_storage_bitwidth / 8),
+        backend_return_whole_row_(backend_return_whole_row),
         feature_evict_config_(feature_evict_config) {
     executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
         num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -609,10 +611,14 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       const at::Tensor& weights,
       const int64_t start,
       const int64_t length) {
-    const auto seq_indices =
-        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+    if (backend_return_whole_row_) {
+      set_kv_with_metaheader_to_storage(weights);
+    } else {
+      const auto seq_indices = at::arange(
+          start, start + length, at::TensorOptions().dtype(at::kLong));
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+    }
   }
 
   void get_range_from_snapshot(
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     CHECK(snapshot_handle == nullptr);
     const auto seq_indices =
         at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    get_kv_db_async_impl(
-        seq_indices, weights, count, width_offset, width_length)
-        .wait();
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(seq_indices, weights);
+    } else {
+      const auto count = at::tensor({length}, at::ScalarType::Long);
+      get_kv_db_async_impl(
+          seq_indices, weights, count, width_offset, width_length)
+          .wait();
+    }
+
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
     // state_dict() calls
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       int64_t width_offset = 0,
       std::optional<int64_t> width_length = std::nullopt) override {
     CHECK(snapshot_handle == nullptr);
+
+    if (backend_return_whole_row_) {
+      get_kv_with_metaheader_from_storage(
+          ids, weights, width_offset, width_length);
+    } else {
+      const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+      get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+          .wait();
+    }
+  }
+
+  // used for ckpt, get kv with metaheader from storage
+  void get_kv_with_metaheader_from_storage(
+      const at::Tensor& ids,
+      const at::Tensor& weights_with_metaheader,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
     const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
-    get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+    get_kv_db_with_metaheader_async_impl(
+        ids, weights_with_metaheader, count, width_offset, width_length)
+        .wait();
+  }
+
+  void set_kv_with_metaheader_to_storage(
+      const at::Tensor& weights_with_metaheader) {
+    std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
+    for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
+      keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
+    }
+    auto indices =
+        torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
+    const auto count =
+        at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
+    set_kv_db_with_metaheader_async_impl(
+        indices, weights_with_metaheader, count)
         .wait();
     // this is called by checkpoint mostly, and checkpoint should wait until
     // eviction finishes so that we could reach a consistent state before/after
@@ -826,6 +871,10 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
 
   void flush_or_compact(const int64_t timestep) override {}
 
+  bool get_backend_return_whole_row() override {
+    return backend_return_whole_row_;
+  }
+
   void resume_ongoing_eviction() override {
     if (feature_evict_) {
       feature_evict_->resume();
@@ -930,6 +979,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
     return ret;
   }
 
+  /// Get embeddings and metaheader from kvstore.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row(embeddings) is
+  /// paired up with relative element in <indices>. This tensor will be
+  /// filled up with the returned embeddings from KVstore.
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  get_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count,
+      int64_t width_offset = 0,
+      std::optional<int64_t> width_length = std::nullopt) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto row_width = weights_with_metaheader.size(1);
+    auto copy_width = width_length.value_or(row_width);
+    CHECK_LE(row_width, block_size_);
+    CHECK_EQ(copy_width, row_width);
+    auto shardid_to_indexes = shard_input(indices, count);
+
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue([this,
+                          shard_id,
+                          indexes,
+                          &indices,
+                          &weights_with_metaheader,
+                          width_offset,
+                          row_width](folly::Unit) {
+                FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                    indices.scalar_type(),
+                    "dram_kvstore_get_with_metaheader",
+                    [this,
+                     shard_id,
+                     indexes,
+                     &indices,
+                     &weights_with_metaheader,
+                     width_offset,
+                     row_width] {
+                      using index_t = scalar_t;
+                      CHECK(indices.is_contiguous());
+                      CHECK(weights_with_metaheader.is_contiguous());
+                      CHECK_EQ(
+                          indices.size(0), weights_with_metaheader.size(0));
+                      auto wlmap = kv_store_.by(shard_id).wlock();
+                      auto indices_data_ptr = indices.data_ptr<index_t>();
+                      auto weights_data_ptr =
+                          weights_with_metaheader.data_ptr<weight_type>();
+                      {
+                        for (auto index_iter = indexes.begin();
+                             index_iter != indexes.end();
+                             index_iter++) {
+                          const auto weights_row_index = *index_iter;
+                          auto weight_idx =
+                              int64_t(indices_data_ptr[weights_row_index]);
+                          const auto cached_iter = wlmap->find(weight_idx);
+                          // Defensive programming
+                          // it shouldn't occur under normal circumstances
+                          if (cached_iter == wlmap->end()) {
+                            std::memset(
+                                &(weights_data_ptr
+                                      [weights_row_index * row_width]),
+                                0,
+                                row_width);
+                            continue;
+                          }
+
+                          // For weight KVT, offset=0 and it will read the whole
+                          // row. For optimizer, offset=dim(metaheader) +
+                          // emb_dim so it will only read the optimizer part
+                          const auto* ptr_offset_from_front =
+                              FixedBlockPool::ptr_offset_from_front<
+                                  weight_type>(
+                                  cached_iter->second, width_offset);
+                          std::copy(
+                              ptr_offset_from_front,
+                              ptr_offset_from_front + row_width,
+                              &(weights_data_ptr
+                                    [weights_row_index * row_width]));
+                        }
+                      }
+                    });
+              });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
+  /// insert embeddings and metaheader into kvstore.
+  /// current underlying memory management is done through F14FastMap
+  /// key value pair will be sharded into multiple shards to increase
+  /// parallelism.
+  ///
+  /// @param indices The 1D embedding index tensor, should skip on negative
+  /// value
+  /// @param weights_with_metaheader The 2D tensor that each row(embeddings with
+  /// metaheader) is paired up with relative element in <indices>
+  /// @param count A single element tensor that contains the number of indices
+  /// to be processed
+  ///
+  /// @return None
+  folly::SemiFuture<std::vector<folly::Unit>>
+  set_kv_db_with_metaheader_async_impl(
+      const at::Tensor& indices,
+      const at::Tensor& weights_with_metaheader,
+      const at::Tensor& count) {
+    std::vector<folly::Future<folly::Unit>> futures;
+    auto shardid_to_indexes = shard_input(indices, count);
+    for (auto iter = shardid_to_indexes.begin();
+         iter != shardid_to_indexes.end();
+         iter++) {
+      const auto shard_id = iter->first;
+      const auto indexes = iter->second;
+      auto f =
+          folly::via(executor_.get())
+              .thenValue(
+                  [this, shard_id, indexes, &indices, &weights_with_metaheader](
+                      folly::Unit) {
+                    FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                        indices.scalar_type(),
+                        "dram_kv_set_with_metaheader",
+                        [this,
+                         shard_id,
+                         indexes,
+                         &indices,
+                         &weights_with_metaheader] {
+                          using index_t = scalar_t;
+                          CHECK(indices.is_contiguous());
+                          CHECK(weights_with_metaheader.is_contiguous());
+                          CHECK_EQ(
+                              indices.size(0), weights_with_metaheader.size(0));
+                          {
+                            auto wlmap = kv_store_.by(shard_id).wlock();
+                            auto* pool = kv_store_.pool_by(shard_id);
+                            int64_t stride = weights_with_metaheader.size(1);
+                            auto indices_data_ptr = indices.data_ptr<index_t>();
+                            auto weights_data_ptr =
+                                weights_with_metaheader.data_ptr<weight_type>();
+                            for (auto index_iter = indexes.begin();
+                                 index_iter != indexes.end();
+                                 index_iter++) {
+                              const auto& id_index = *index_iter;
+                              auto id = int64_t(indices_data_ptr[id_index]);
+                              // Defensive programming
+                              // it shouldn't occur under normal circumstances
+                              auto used = FixedBlockPool::get_used(
+                                  weights_data_ptr + id_index * stride);
+                              if (!used) {
+                                continue;
+                              }
+                              // use mempool
+                              weight_type* block = nullptr;
+                              // First check if the key already exists
+                              auto it = wlmap->find(id);
+                              if (it != wlmap->end()) {
+                                block = it->second;
+                              } else {
+                                // Key doesn't exist, allocate new block and
+                                // insert.
+                                block =
+                                    pool->template allocate_t<weight_type>();
+                                wlmap->insert({id, block});
+                              }
+                              std::copy(
+                                  weights_data_ptr + id_index * stride,
+                                  weights_data_ptr + (id_index + 1) * stride,
+                                  block);
+                            }
+                          }
+                        });
+                  });
+      futures.push_back(std::move(f));
+    }
+    return folly::collect(futures);
+  }
+
   std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
   // background thread
   folly::FunctionScheduler scheduler_;
@@ -942,6 +1177,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
   std::atomic_bool is_eviction_ongoing_ = false;
   std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
   int64_t elem_size_;
+  bool backend_return_whole_row_;
   std::vector<int64_t> sub_table_dims_;
   std::vector<int64_t> sub_table_hash_cumsum_;
   std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;
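
Roughly, the checkpoint-load side can recover ids from the saved whole rows the same way set_kv_with_metaheader_to_storage() above does. The sketch below is illustrative only: `get_key_from_row` is a stand-in for `FixedBlockPool::get_key`, and the assumption that the id occupies the first 8 bytes of the metaheader is made purely for the example.

```cpp
#include <torch/torch.h>

#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in for FixedBlockPool::get_key (illustration only); assumes the id is
// stored in the first 8 bytes of the row's metaheader.
int64_t get_key_from_row(const void* row) {
  int64_t key = 0;
  std::memcpy(&key, row, sizeof(key));
  return key;
}

// Build the id tensor from a batch of whole rows, mirroring the key-extraction
// loop in set_kv_with_metaheader_to_storage().
at::Tensor ids_from_whole_rows(const at::Tensor& rows_with_metaheader) {
  const int64_t n = rows_with_metaheader.size(0);
  std::vector<int64_t> keys(n, 0);
  for (int64_t i = 0; i < n; ++i) {
    keys[i] = get_key_from_row(rows_with_metaheader[i].data_ptr());
  }
  // Clone so the returned tensor owns its storage after `keys` goes away.
  return torch::from_blob(keys.data(), {n}, torch::kInt64).clone();
}
```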

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       int64_t row_storage_bitwidth = 32,
       const std::optional<at::Tensor>& table_dims = std::nullopt,
       const std::optional<at::Tensor>& hash_size_cumsum = std::nullopt,
+      bool backend_return_whole_row = false,
       bool enable_async_update = false) {
     if (row_storage_bitwidth == 16) {
       impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
@@ -47,6 +48,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           num_shards,
           num_threads,
           row_storage_bitwidth,
+          backend_return_whole_row,
           enable_async_update,
           table_dims,
           hash_size_cumsum);
@@ -59,6 +61,7 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           num_shards,
           num_threads,
           row_storage_bitwidth,
+          backend_return_whole_row,
           enable_async_update,
           table_dims,
           hash_size_cumsum);

fbgemm_gpu/src/dram_kv_embedding_cache/fixed_block_pool.h

Lines changed: 16 additions & 0 deletions
@@ -114,6 +114,22 @@ class FixedBlockPool : public std::pmr::memory_resource {
         sizeof(FixedBlockPool::MetaHeader));
   }
 
+  template <typename scalar_t>
+  static scalar_t* ptr_offset_from_front(
+      scalar_t* block,
+      const int64_t offset) {
+    return reinterpret_cast<scalar_t*>(
+        reinterpret_cast<char*>(block) + offset * sizeof(scalar_t));
+  }
+
+  template <typename scalar_t>
+  static const scalar_t* ptr_offset_from_front(
+      const scalar_t* block,
+      const int64_t offset) {
+    return reinterpret_cast<const scalar_t*>(
+        reinterpret_cast<const char*>(block) + offset * sizeof(scalar_t));
+  }
+
   template <typename scalar_t>
   static scalar_t get_l2weight(scalar_t* block, size_t dimension) {
     scalar_t* data = FixedBlockPool::data_ptr(block);
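
The offset passed to ptr_offset_from_front is counted in scalar_t elements from the front of the block; per the comment in get_kv_db_with_metaheader_async_impl, the weight KVTensor reads from offset 0 (the whole row) and the optimizer KVTensor from offset = dim(metaheader) + emb_dim. A self-contained sketch of that arithmetic, with assumed dimensions and a behavior-equivalent local re-implementation of the helper:

```cpp
#include <cstdint>
#include <vector>

// Local re-implementation for illustration only; the real helper is
// FixedBlockPool::ptr_offset_from_front (see the diff above). Advancing the
// char* by offset * sizeof(scalar_t) equals advancing the typed pointer.
template <typename scalar_t>
scalar_t* ptr_offset_from_front_sketch(scalar_t* block, int64_t offset) {
  return block + offset;
}

int main() {
  // Assumed dimensions, purely for the example.
  const int64_t metaheader_dim = 4, emb_dim = 8, opt_dim = 8;
  std::vector<float> row(metaheader_dim + emb_dim + opt_dim, 0.0f);

  // Weight KVTensor view: offset 0 reads the whole row (metaheader included).
  float* whole_row = ptr_offset_from_front_sketch(row.data(), int64_t{0});
  // Optimizer KVTensor view: skip metaheader + embedding to reach the tail.
  float* opt_state =
      ptr_offset_from_front_sketch(row.data(), metaheader_dim + emb_dim);
  (void)whole_row;
  (void)opt_state;
  return 0;
}
```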

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 9 additions & 4 deletions
@@ -357,10 +357,10 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
       const at::Tensor& weights,
       const int64_t start,
       const int64_t length) {
-    const auto seq_indices =
-        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-    const auto count = at::tensor({length}, at::ScalarType::Long);
-    folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+    (void)weights;
+    (void)start;
+    (void)length;
+    FBEXCEPTION("Not implemented");
   }
 
   virtual void get_range_from_snapshot(
@@ -402,6 +402,11 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
     return max_D_;
   }
 
+  virtual bool get_backend_return_whole_row() {
+    // only DRAM backend can enable this for now
+    return false;
+  }
+
 #ifdef FBGEMM_FBCODE
   folly::coro::Task<void> tensor_stream(
     const at::Tensor& indices,
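
Since the base class now defaults get_backend_return_whole_row() to false (with the DRAM backend overriding it to true), a caller can branch on that capability. A hedged caller-side sketch only; the real checkpoint integration lives in the linked TorchRec change (pytorch/torchrec#3148), not in this diff:

```cpp
#include <ATen/ATen.h>
// Header shown in this diff; adjust the include path to your build setup.
#include "kv_db_table_batched_embeddings.h"

void restore_rows(kv_db::EmbeddingKVDB& db, const at::Tensor& rows) {
  if (db.get_backend_return_whole_row()) {
    // DRAM backend: `rows` carries [metaheader | weight | optimizer], and its
    // set_range_to_storage() routes through set_kv_with_metaheader_to_storage(),
    // recovering ids from the row headers.
    db.set_range_to_storage(rows, /*start=*/0, /*length=*/rows.size(0));
  } else {
    // Other backends keep the existing id/weight/optimizer three-tensor flow
    // (apply_state_dict), since the base set_range_to_storage() now throws.
  }
}
```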
