@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
int64_t num_shards = 8,
int64_t num_threads = 32,
int64_t row_storage_bitwidth = 32,
+ bool backend_return_whole_row = false,
bool enable_async_update = false,
std::optional<at::Tensor> table_dims = std::nullopt,
std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
block_alignment_,
/*blocks_per_chunk=*/8192)),
elem_size_(row_storage_bitwidth / 8),
+ backend_return_whole_row_(backend_return_whole_row),
feature_evict_config_(feature_evict_config) {
executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t>(
num_threads, facebook::Proc::getCpuInfo().numCpuCores));
@@ -608,11 +610,15 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
void set_range_to_storage(
const at::Tensor& weights,
const int64_t start,
- const int64_t length) {
- const auto seq_indices =
- at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
- const auto count = at::tensor({length}, at::ScalarType::Long);
- folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+ const int64_t length) override {
+ if (backend_return_whole_row_) {
+ set_kv_with_metaheader_to_storage(weights);
+ } else {
+ const auto seq_indices = at::arange(
+ start, start + length, at::TensorOptions().dtype(at::kLong));
+ const auto count = at::tensor({length}, at::ScalarType::Long);
+ folly::coro::blockingWait(set_kv_db_async(seq_indices, weights, count));
+ }
}

void get_range_from_snapshot(
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
CHECK(snapshot_handle == nullptr);
const auto seq_indices =
at::arange(start, start + length, at::TensorOptions().dtype(at::kLong));
- const auto count = at::tensor({length}, at::ScalarType::Long);
- get_kv_db_async_impl(
- seq_indices, weights, count, width_offset, width_length)
- .wait();
+
+ if (backend_return_whole_row_) {
+ get_kv_with_metaheader_from_storage(seq_indices, weights);
+ } else {
+ const auto count = at::tensor({length}, at::ScalarType::Long);
+ get_kv_db_async_impl(
+ seq_indices, weights, count, width_offset, width_length)
+ .wait();
+ }
+
// this is called by checkpoint mostly, and checkpoint should wait until
// eviction finishes so that we can reach a consistent state before/after
// state_dict() calls
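
For orientation, a caller-side sketch of the new write path (not part of the patch; `cache`, `chunk`, and `chunk_start` are illustrative names). When the backend returns whole rows, set_range_to_storage routes to set_kv_with_metaheader_to_storage and each row is keyed by the metaheader it carries, so `start`/`length` only drive the non-whole-row branch:

    // Hypothetical caller, assuming a constructed DramKVEmbeddingCache `cache`
    // and a contiguous 2D tensor `chunk` holding metaheader-prefixed rows.
    cache.set_range_to_storage(chunk, /*start=*/chunk_start, /*length=*/chunk.size(0));
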
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
int64_t width_offset = 0,
std::optional<int64_t> width_length = std::nullopt) override {
CHECK(snapshot_handle == nullptr);
+
+ if (backend_return_whole_row_) {
+ get_kv_with_metaheader_from_storage(
+ ids, weights, width_offset, width_length);
+ } else {
+ const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
+ get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+ .wait();
+ }
+ }
+
+ // used for ckpt, get kv with metaheader from storage
+ void get_kv_with_metaheader_from_storage(
+ const at::Tensor& ids,
+ const at::Tensor& weights_with_metaheader,
+ int64_t width_offset = 0,
+ std::optional<int64_t> width_length = std::nullopt) {
const auto count = at::tensor({ids.size(0)}, at::ScalarType::Long);
- get_kv_db_async_impl(ids, weights, count, width_offset, width_length)
+ get_kv_db_with_metaheader_async_impl(
+ ids, weights_with_metaheader, count, width_offset, width_length)
+ .wait();
+ }
+
+ void set_kv_with_metaheader_to_storage(
+ const at::Tensor& weights_with_metaheader) {
+ std::vector<int64_t> keys(weights_with_metaheader.size(0), 0);
+ for (int64_t i = 0; i < weights_with_metaheader.size(0); ++i) {
+ keys[i] = FixedBlockPool::get_key(weights_with_metaheader[i].data_ptr());
+ }
+ auto indices =
+ torch::from_blob(keys.data(), {int64_t(keys.size())}, torch::kInt64);
+ const auto count =
+ at::tensor({weights_with_metaheader.size(0)}, at::ScalarType::Long);
+ set_kv_db_with_metaheader_async_impl(
+ indices, weights_with_metaheader, count)
.wait();
// this is called by checkpoint mostly, and checkpoint should wait until
// eviction finishes so that we can reach a consistent state before/after
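
A hedged round-trip sketch for the two helpers added above (illustrative names; the dtype and `row_width` are assumptions, not part of the patch). Reads fill the buffer with metaheader-prefixed rows for the given `ids`; writes recover each row's key from its metaheader and skip rows whose `used` flag is unset:

    // Assumes `cache`, an int64 `ids` tensor, and a `row_width` wide enough for
    // the metaheader plus the embedding (and optimizer, if stored) columns.
    auto rows = at::zeros({ids.size(0), row_width}, at::kFloat);
    cache.get_kv_with_metaheader_from_storage(ids, rows);
    cache.set_kv_with_metaheader_to_storage(rows); // keys come from the metaheaders
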
@@ -826,6 +871,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {

void flush_or_compact(const int64_t timestep) override {}

+ bool get_backend_return_whole_row() override {
+ return backend_return_whole_row_;
+ }
+
+ int64_t get_metaheader_width_in_front() override {
+ return backend_return_whole_row_
+ ? FixedBlockPool::get_metaheader_dim<weight_type>()
+ : 0;
+ }
+
void resume_ongoing_eviction() override {
if (feature_evict_) {
feature_evict_->resume();
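
A small sketch of how a checkpoint consumer might size its buffers with the two accessors added above (everything except the accessor names is an assumption):

    // Rows are metaheader-prefixed only when whole-row mode is enabled;
    // get_metaheader_width_in_front() already returns 0 otherwise.
    const int64_t meta_width = cache.get_metaheader_width_in_front();
    auto buf = at::zeros({num_ids, meta_width + payload_width}, at::kFloat);
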
@@ -930,6 +985,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
return ret;
}

+ /// Get embeddings and metaheader from kvstore.
+ ///
+ /// @param indices The 1D embedding index tensor, should skip on negative
+ /// values
+ /// @param weights_with_metaheader The 2D tensor in which each row (embedding) is
+ /// paired with the corresponding element in <indices>. This tensor will be
+ /// filled with the returned embeddings from KVstore.
+ /// @param count A single element tensor that contains the number of indices
+ /// to be processed
+ ///
+ /// @return None
+ folly::SemiFuture<std::vector<folly::Unit>>
+ get_kv_db_with_metaheader_async_impl(
+ const at::Tensor& indices,
+ const at::Tensor& weights_with_metaheader,
+ const at::Tensor& count,
+ int64_t width_offset = 0,
+ std::optional<int64_t> width_length = std::nullopt) {
+ std::vector<folly::Future<folly::Unit>> futures;
+ auto row_width = weights_with_metaheader.size(1);
+ auto copy_width = width_length.value_or(row_width);
+ CHECK_LE(row_width, block_size_);
+ CHECK_EQ(copy_width, row_width);
+ auto shardid_to_indexes = shard_input(indices, count);
+
+ for (auto iter = shardid_to_indexes.begin();
+ iter != shardid_to_indexes.end();
+ iter++) {
+ const auto shard_id = iter->first;
+ const auto indexes = iter->second;
+ auto f =
+ folly::via(executor_.get())
+ .thenValue([this,
+ shard_id,
+ indexes,
+ &indices,
+ &weights_with_metaheader,
+ width_offset,
+ row_width](folly::Unit) {
+ FBGEMM_DISPATCH_INTEGRAL_TYPES(
+ indices.scalar_type(),
+ "dram_kvstore_get_with_metaheader",
+ [this,
+ shard_id,
+ indexes,
+ &indices,
+ &weights_with_metaheader,
+ width_offset,
+ row_width] {
+ using index_t = scalar_t;
+ CHECK(indices.is_contiguous());
+ CHECK(weights_with_metaheader.is_contiguous());
+ CHECK_EQ(
+ indices.size(0), weights_with_metaheader.size(0));
+ auto wlmap = kv_store_.by(shard_id).wlock();
+ auto indices_data_ptr = indices.data_ptr<index_t>();
+ auto weights_data_ptr =
+ weights_with_metaheader.data_ptr<weight_type>();
+ {
+ for (auto index_iter = indexes.begin();
+ index_iter != indexes.end();
+ index_iter++) {
+ const auto weights_row_index = *index_iter;
+ auto weight_idx =
+ int64_t(indices_data_ptr[weights_row_index]);
+ const auto cached_iter = wlmap->find(weight_idx);
+ // Defensive programming
+ // it shouldn't occur under normal circumstances
+ if (cached_iter == wlmap->end()) {
+ std::memset(
+ &(weights_data_ptr
+ [weights_row_index * row_width]),
+ 0,
+ row_width);
+ continue;
+ }
+
+ // For weight KVT, offset=0 and it will read the whole
+ // row. For optimizer, offset=dim(metaheader) +
+ // emb_dim so it will only read the optimizer part
+ const auto* ptr_offset_from_front =
+ FixedBlockPool::ptr_offset_from_front<weight_type>(
+ cached_iter->second, width_offset);
+ std::copy(
+ ptr_offset_from_front,
+ ptr_offset_from_front + row_width,
+ &(weights_data_ptr
+ [weights_row_index * row_width]));
+ }
+ }
+ });
+ });
+ futures.push_back(std::move(f));
+ }
+ return folly::collect(futures);
+ }
+
+ /// Insert embeddings and metaheader into kvstore.
+ /// Current underlying memory management is done through F14FastMap;
+ /// key/value pairs will be sharded into multiple shards to increase
+ /// parallelism.
+ ///
+ /// @param indices The 1D embedding index tensor, should skip on negative
+ /// values
+ /// @param weights_with_metaheader The 2D tensor in which each row (embedding with
+ /// metaheader) is paired with the corresponding element in <indices>
+ /// @param count A single element tensor that contains the number of indices
+ /// to be processed
+ ///
+ /// @return None
+ folly::SemiFuture<std::vector<folly::Unit>>
+ set_kv_db_with_metaheader_async_impl(
+ const at::Tensor& indices,
+ const at::Tensor& weights_with_metaheader,
+ const at::Tensor& count) {
+ std::vector<folly::Future<folly::Unit>> futures;
+ auto shardid_to_indexes = shard_input(indices, count);
+ for (auto iter = shardid_to_indexes.begin();
+ iter != shardid_to_indexes.end();
+ iter++) {
+ const auto shard_id = iter->first;
+ const auto indexes = iter->second;
+ auto f =
+ folly::via(executor_.get())
+ .thenValue(
+ [this, shard_id, indexes, &indices, &weights_with_metaheader](
+ folly::Unit) {
+ FBGEMM_DISPATCH_INTEGRAL_TYPES(
+ indices.scalar_type(),
+ "dram_kv_set_with_metaheader",
+ [this,
+ shard_id,
+ indexes,
+ &indices,
+ &weights_with_metaheader] {
+ using index_t = scalar_t;
+ CHECK(indices.is_contiguous());
+ CHECK(weights_with_metaheader.is_contiguous());
+ CHECK_EQ(
+ indices.size(0), weights_with_metaheader.size(0));
+ {
+ auto wlmap = kv_store_.by(shard_id).wlock();
+ auto* pool = kv_store_.pool_by(shard_id);
+ int64_t stride = weights_with_metaheader.size(1);
+ auto indices_data_ptr = indices.data_ptr<index_t>();
+ auto weights_data_ptr =
+ weights_with_metaheader.data_ptr<weight_type>();
+ for (auto index_iter = indexes.begin();
+ index_iter != indexes.end();
+ index_iter++) {
+ const auto& id_index = *index_iter;
+ auto id = int64_t(indices_data_ptr[id_index]);
+ // Defensive programming
+ // it shouldn't occur under normal circumstances
+ auto used = FixedBlockPool::get_used(
+ weights_data_ptr + id_index * stride);
+ if (!used) {
+ continue;
+ }
+ // use mempool
+ weight_type* block = nullptr;
+ // First check if the key already exists
+ auto it = wlmap->find(id);
+ if (it != wlmap->end()) {
+ block = it->second;
+ } else {
+ // Key doesn't exist, allocate new block and
+ // insert.
+ block =
+ pool->template allocate_t<weight_type>();
+ wlmap->insert({id, block});
+ }
+ std::copy(
+ weights_data_ptr + id_index * stride,
+ weights_data_ptr + (id_index + 1) * stride,
+ block);
+ }
+ }
+ });
+ });
+ futures.push_back(std::move(f));
+ }
+ return folly::collect(futures);
+ }
+
std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
// background thread
folly::FunctionScheduler scheduler_;
@@ -942,6 +1183,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
std::atomic_bool is_eviction_ongoing_ = false;
std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
int64_t elem_size_;
+ bool backend_return_whole_row_;
std::vector<int64_t> sub_table_dims_;
std::vector<int64_t> sub_table_hash_cumsum_;
std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;