@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
      int64_t num_shards = 8,
      int64_t num_threads = 32,
      int64_t row_storage_bitwidth = 32,
+     bool backend_return_whole_row = false,
      bool enable_async_update = false,
      std::optional<at::Tensor> table_dims = std::nullopt,
      std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
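
A minimal sketch of what the new backend_return_whole_row flag controls, under an assumed row layout: each cached row is a metaheader (key, used flag, etc.) followed by the embedding payload, and the flag decides whether callers exchange the full row or only the payload. The MetaHeader struct and read_row helper below are illustrative stand-ins, not FBGEMM's actual FixedBlockPool format.

// Illustrative sketch only: MetaHeader and the row layout are assumptions
// for exposition, not FBGEMM's FixedBlockPool format.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

struct MetaHeader {
  int64_t key;    // embedding id recoverable from the row itself
  uint32_t used;  // whether the block currently holds a valid row
};

// One stored "whole row": metaheader followed by the embedding payload.
struct WholeRow {
  MetaHeader header;
  std::vector<float> payload;
};

// Copy either the whole row (metaheader + payload) or just the payload,
// mirroring what backend_return_whole_row gates in the backend.
std::vector<float> read_row(const WholeRow& row, bool return_whole_row) {
  const size_t header_elems = sizeof(MetaHeader) / sizeof(float);
  std::vector<float> out;
  if (return_whole_row) {
    out.resize(header_elems + row.payload.size());
    std::memcpy(out.data(), &row.header, sizeof(MetaHeader));
    std::copy(row.payload.begin(), row.payload.end(),
              out.begin() + header_elems);
  } else {
    out = row.payload;  // embedding values only
  }
  return out;
}

int main() {
  WholeRow row{{/*key=*/42, /*used=*/1}, {0.1f, 0.2f, 0.3f, 0.4f}};
  std::cout << "whole row elems: " << read_row(row, true).size() << "\n"
            << "payload elems:   " << read_row(row, false).size() << "\n";
  return 0;
}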
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
            block_alignment_,
            /* blocks_per_chunk=*/8192)),
        elem_size_(row_storage_bitwidth / 8),
+       backend_return_whole_row_(backend_return_whole_row),
        feature_evict_config_(feature_evict_config) {
    executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
        num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -609,10 +611,14 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
      const at::Tensor& weights,
      const int64_t start,
      const int64_t length) {
-   const auto seq_indices =
-       at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-   const auto count = at::tensor({length}, at::ScalarType::Long);
-   folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+   if (backend_return_whole_row_) {
+     set_kv_with_metaheader_to_storage(weights);
+   } else {
+     const auto seq_indices = at::arange(
+         start, start + length, at::TensorOptions().dtype(at::kLong));
+     const auto count = at::tensor({length}, at::ScalarType::Long);
+     folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+   }
  }

  void get_range_from_snapshot(
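
On the non-whole-row path, ids and count keep the existing convention: a contiguous int64 range built with at::arange plus a single-element count tensor. A standalone ATen sketch of those two tensors (the start/length values are arbitrary):

// Sketch of the index/count convention used by the non-whole-row path:
// a contiguous id range [start, start + length) and a one-element count.
#include <ATen/ATen.h>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t start = 100;
  const int64_t length = 4;

  // 1-D int64 tensor of ids to read/write: [100, 101, 102, 103].
  const auto seq_indices =
      at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));

  // Single-element tensor telling the backend how many ids are valid.
  const auto count = at::tensor({length}, at::ScalarType::Long);

  std::cout << "ids:\n" << seq_indices << "\n"
            << "count: " << count.item<int64_t>() << "\n";
  return 0;
}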
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
    CHECK(snapshot_handle == nullptr);
    const auto seq_indices =
        at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
-   const auto count = at::tensor({length}, at::ScalarType::Long);
-   get_kv_db_async_impl(
-       seq_indices, weights, count, width_offset, width_length)
-       .wait();
+
+   if (backend_return_whole_row_) {
+     get_kv_with_metaheader_from_storage(seq_indices, weights);
+   } else {
+     const auto count = at::tensor({length}, at::ScalarType::Long);
+     get_kv_db_async_impl(
+         seq_indices, weights, count, width_offset, width_length)
+         .wait();
+   }
+
    // this is mostly called by checkpoint, and checkpoint should wait until
    // eviction finishes so that we can reach a consistent state before/after
    // state_dict() calls
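
The comment above carries the reasoning: a checkpoint read must not overlap an in-flight eviction pass, or state_dict() could see a half-evicted table. A self-contained, illustrative-only sketch of that wait-for-eviction idea using a condition variable; the names and protocol here are assumptions, not the class's actual mechanism (which pauses/resumes feature_evict_):

// Illustrative sketch: a checkpoint reader blocks until no eviction pass is
// in flight, so it serializes a consistent snapshot. Not the actual FBGEMM
// mechanism.
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

std::mutex mu;
std::condition_variable cv;
bool eviction_ongoing = false;

void run_eviction_pass() {
  {
    std::lock_guard<std::mutex> lk(mu);
    eviction_ongoing = true;
  }
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // evicting...
  {
    std::lock_guard<std::mutex> lk(mu);
    eviction_ongoing = false;
  }
  cv.notify_all();
}

void checkpoint_read() {
  // Block until any in-flight eviction pass finishes, then read.
  std::unique_lock<std::mutex> lk(mu);
  cv.wait(lk, [] { return !eviction_ongoing; });
  std::cout << "consistent state reached; safe to serialize rows\n";
}

int main() {
  std::thread evictor(run_eviction_pass);
  std::thread reader(checkpoint_read);
  evictor.join();
  reader.join();
  return 0;
}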
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
      int64_t width_offset = 0,
      std::optional<int64_t> width_length = std::nullopt) override {
    CHECK(snapshot_handle == nullptr);
+
+   if (backend_return_whole_row_) {
+     get_kv_with_metaheader_from_storage(
+         ids, weights, width_offset, width_length);
+   } else {
+     const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+     get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+         .wait();
+   }
+ }
+
+ // used for ckpt, get kv with metaheader from storage
+ void get_kv_with_metaheader_from_storage(
+     const at::Tensor& ids,
+     const at::Tensor& weights_with_metaheader,
+     int64_t width_offset = 0,
+     std::optional<int64_t> width_length = std::nullopt) {
    const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
-   get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+   get_kv_db_with_metaheader_async_impl(
+       ids, weights_with_metaheader, count, width_offset, width_length)
+       .wait();
+ }
+
+ void set_kv_with_metaheader_to_storage(
+     const at::Tensor& weights_with_metaheader) {
+   std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
+   for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
+     keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
+   }
+   auto indices =
+       torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
+   const auto count =
+       at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
+   set_kv_db_with_metaheader_async_impl(
+       indices, weights_with_metaheader, count)
        .wait();
    // this is mostly called by checkpoint, and checkpoint should wait until
    // eviction finishes so that we can reach a consistent state before/after
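
set_kv_with_metaheader_to_storage recovers each row's id from the row itself via FixedBlockPool::get_key and wraps the recovered keys with torch::from_blob, which borrows the buffer instead of copying it. A self-contained sketch of that pattern, assuming the key sits in the first 8 bytes of each row; key_at_row_start below is a hypothetical stand-in for FixedBlockPool::get_key:

// Sketch: recover ids from rows-with-metaheader and wrap them in a tensor
// without copying. Assumption: the int64 key is stored at the start of each
// row; key_at_row_start() is a hypothetical stand-in for
// FixedBlockPool::get_key().
#include <torch/torch.h>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int64_t key_at_row_start(const void* row_ptr) {
  int64_t key;
  std::memcpy(&key, row_ptr, sizeof(key));  // assumed layout: key first
  return key;
}

int main() {
  // Two fake rows whose first 8 bytes encode the keys 7 and 9.
  const int64_t row_width = 8;  // elements per row (incl. "metaheader")
  auto rows = torch::zeros({2, row_width}, torch::kFloat32);
  int64_t k0 = 7, k1 = 9;
  std::memcpy(rows[0].data_ptr(), &k0, sizeof(k0));
  std::memcpy(rows[1].data_ptr(), &k1, sizeof(k1));

  std::vector<int64_t> keys(rows.size(0), 0);
  for (int64_t i = 0; i < rows.size(0); ++i) {
    keys[i] = key_at_row_start(rows[i].data_ptr());
  }

  // from_blob borrows keys.data(): no copy, but `keys` must outlive any use
  // of `indices` (the member function above satisfies this by blocking on
  // .wait() before the vector goes out of scope).
  auto indices =
      torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
  std::cout << indices << "\n";  // prints 7 and 9
  return 0;
}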
@@ -826,6 +871,10 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {

  void flush_or_compact(const int64_t timestep) override {}

+ bool get_backend_return_whole_row() override {
+   return backend_return_whole_row_;
+ }
+
  void resume_ongoing_eviction() override {
    if (feature_evict_) {
      feature_evict_->resume();
@@ -930,6 +979,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
    return ret;
  }

+ /// Get embeddings and metaheader from kvstore.
+ ///
+ /// @param indices The 1D embedding index tensor; negative values are
+ /// skipped.
+ /// @param weights_with_metaheader The 2D tensor in which each row
+ /// (embedding) is paired with the corresponding element of <indices>. This
+ /// tensor is filled with the embeddings returned from the kvstore.
+ /// @param count A single-element tensor containing the number of indices
+ /// to be processed.
+ ///
+ /// @return None
+ folly::SemiFuture<std::vector<folly::Unit>>
+ get_kv_db_with_metaheader_async_impl(
+     const at::Tensor& indices,
+     const at::Tensor& weights_with_metaheader,
+     const at::Tensor& count,
+     int64_t width_offset = 0,
+     std::optional<int64_t> width_length = std::nullopt) {
+   std::vector<folly::Future<folly::Unit>> futures;
+   auto row_width = weights_with_metaheader.size(1);
+   auto copy_width = width_length.value_or(row_width);
+   CHECK_LE(row_width, block_size_);
+   CHECK_EQ(copy_width, row_width);
+   auto shardid_to_indexes = shard_input(indices, count);
+
+   for (auto iter = shardid_to_indexes.begin();
+        iter != shardid_to_indexes.end();
+        iter++) {
+     const auto shard_id = iter->first;
+     const auto indexes = iter->second;
+     auto f =
+         folly::via(executor_.get())
+             .thenValue([this,
+                         shard_id,
+                         indexes,
+                         &indices,
+                         &weights_with_metaheader,
+                         width_offset,
+                         row_width](folly::Unit) {
+               FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                   indices.scalar_type(),
+                   "dram_kvstore_get_with_metaheader",
+                   [this,
+                    shard_id,
+                    indexes,
+                    &indices,
+                    &weights_with_metaheader,
+                    width_offset,
+                    row_width] {
+                     using index_t = scalar_t;
+                     CHECK(indices.is_contiguous());
+                     CHECK(weights_with_metaheader.is_contiguous());
+                     CHECK_EQ(
+                         indices.size(0), weights_with_metaheader.size(0));
+                     auto wlmap = kv_store_.by(shard_id).wlock();
+                     auto indices_data_ptr = indices.data_ptr<index_t>();
+                     auto weights_data_ptr =
+                         weights_with_metaheader.data_ptr<weight_type>();
+                     {
+                       for (auto index_iter = indexes.begin();
+                            index_iter != indexes.end();
+                            index_iter++) {
+                         const auto weights_row_index = *index_iter;
+                         auto weight_idx =
+                             int64_t(indices_data_ptr[weights_row_index]);
+                         const auto cached_iter = wlmap->find(weight_idx);
+                         // Defensive programming
+                         // it shouldn't occur under normal circumstances
+                         if (cached_iter == wlmap->end()) {
+                           std::memset(
+                               &(weights_data_ptr
+                                     [weights_row_index * row_width]),
+                               0,
+                               row_width);
+                           continue;
+                         }
+
+                         // For weight KVT, offset=0 and it will read the
+                         // whole row. For optimizer, offset=dim(metaheader) +
+                         // emb_dim so it will only read the optimizer part
+                         const auto* ptr_offset_from_front =
+                             FixedBlockPool::ptr_offset_from_front<
+                                 weight_type>(
+                                 cached_iter->second, width_offset);
+                         std::copy(
+                             ptr_offset_from_front,
+                             ptr_offset_from_front + row_width,
+                             &(weights_data_ptr
+                                   [weights_row_index * row_width]));
+                       }
+                     }
+                   });
+             });
+     futures.push_back(std::move(f));
+   }
+   return folly::collect(futures);
+ }
+
+ /// Insert embeddings and metaheader into kvstore.
+ /// The current underlying memory management is done through F14FastMap;
+ /// key-value pairs are sharded into multiple shards to increase
+ /// parallelism.
+ ///
+ /// @param indices The 1D embedding index tensor; negative values are
+ /// skipped.
+ /// @param weights_with_metaheader The 2D tensor in which each row
+ /// (embedding with metaheader) is paired with the corresponding element of
+ /// <indices>.
+ /// @param count A single-element tensor containing the number of indices
+ /// to be processed.
+ ///
+ /// @return None
+ folly::SemiFuture<std::vector<folly::Unit>>
+ set_kv_db_with_metaheader_async_impl(
+     const at::Tensor& indices,
+     const at::Tensor& weights_with_metaheader,
+     const at::Tensor& count) {
+   std::vector<folly::Future<folly::Unit>> futures;
+   auto shardid_to_indexes = shard_input(indices, count);
+   for (auto iter = shardid_to_indexes.begin();
+        iter != shardid_to_indexes.end();
+        iter++) {
+     const auto shard_id = iter->first;
+     const auto indexes = iter->second;
+     auto f =
+         folly::via(executor_.get())
+             .thenValue(
+                 [this, shard_id, indexes, &indices, &weights_with_metaheader](
+                     folly::Unit) {
+                   FBGEMM_DISPATCH_INTEGRAL_TYPES(
+                       indices.scalar_type(),
+                       "dram_kv_set_with_metaheader",
+                       [this,
+                        shard_id,
+                        indexes,
+                        &indices,
+                        &weights_with_metaheader] {
+                         using index_t = scalar_t;
+                         CHECK(indices.is_contiguous());
+                         CHECK(weights_with_metaheader.is_contiguous());
+                         CHECK_EQ(
+                             indices.size(0),
+                             weights_with_metaheader.size(0));
+                         {
+                           auto wlmap = kv_store_.by(shard_id).wlock();
+                           auto* pool = kv_store_.pool_by(shard_id);
+                           int64_t stride = weights_with_metaheader.size(1);
+                           auto indices_data_ptr =
+                               indices.data_ptr<index_t>();
+                           auto weights_data_ptr =
+                               weights_with_metaheader.data_ptr<weight_type>();
+                           for (auto index_iter = indexes.begin();
+                                index_iter != indexes.end();
+                                index_iter++) {
+                             const auto& id_index = *index_iter;
+                             auto id = int64_t(indices_data_ptr[id_index]);
+                             // Defensive programming
+                             // it shouldn't occur under normal circumstances
+                             auto used = FixedBlockPool::get_used(
+                                 weights_data_ptr + id_index * stride);
+                             if (!used) {
+                               continue;
+                             }
+                             // use mempool
+                             weight_type* block = nullptr;
+                             // First check if the key already exists
+                             auto it = wlmap->find(id);
+                             if (it != wlmap->end()) {
+                               block = it->second;
+                             } else {
+                               // Key doesn't exist, allocate new block and
+                               // insert.
+                               block =
+                                   pool->template allocate_t<weight_type>();
+                               wlmap->insert({id, block});
+                             }
+                             std::copy(
+                                 weights_data_ptr + id_index * stride,
+                                 weights_data_ptr + (id_index + 1) * stride,
+                                 block);
+                           }
+                         }
+                       });
+                 });
+     futures.push_back(std::move(f));
+   }
+   return folly::collect(futures);
+ }
+
  std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
  // background thread
  folly::FunctionScheduler scheduler_;
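
Both *_with_metaheader_async_impl methods reuse the existing fan-out structure: ids are bucketed per shard by shard_input, each shard's copy loop runs on the CPUThreadPoolExecutor under that shard's write lock, and the per-shard futures are gathered with folly::collect. A minimal, self-contained sketch of that pattern follows; the modulo sharding is a placeholder, not the class's actual shard_input policy.

// Sketch of the shard-and-fan-out pattern used by the async impls: bucket
// ids per shard, run each shard's work on a CPUThreadPoolExecutor, and
// collect the per-shard futures. Modulo sharding is a placeholder policy.
#include <folly/executors/CPUThreadPoolExecutor.h>
#include <folly/futures/Future.h>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  folly::CPUThreadPoolExecutor executor(4);
  const int64_t num_shards = 4;
  std::vector<int64_t> ids = {3, 8, 15, 20, 21, 42};

  // Bucket ids by shard (placeholder policy: id % num_shards).
  std::map<int64_t, std::vector<int64_t>> shard_to_ids;
  for (auto id : ids) {
    shard_to_ids[id % num_shards].push_back(id);
  }

  // One future per shard; each does its (trivial) per-shard work on the pool.
  std::vector<folly::Future<folly::Unit>> futures;
  for (const auto& entry : shard_to_ids) {
    const int64_t shard_id = entry.first;
    const std::vector<int64_t> shard_ids = entry.second;
    futures.push_back(
        folly::via(&executor).thenValue([shard_id, shard_ids](folly::Unit) {
          std::cout << "shard " << shard_id << " handled " << shard_ids.size()
                    << " ids\n";
        }));
  }

  // Block until every shard finished, mirroring the .wait() in the callers.
  folly::collect(futures).wait();
  return 0;
}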
@@ -942,6 +1177,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
  std::atomic_bool is_eviction_ongoing_ = false;
  std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
  int64_t elem_size_;
+ bool backend_return_whole_row_;
  std::vector<int64_t> sub_table_dims_;
  std::vector<int64_t> sub_table_hash_cumsum_;
  std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;