Skip to content

Commit 789ea2c

Browse files
emlin authored and facebook-github-bot committed
add virtual table eviction policy (#4433)
Summary:
X-link: pytorch/torchrec#3172
X-link: facebookresearch/FBGEMM#1498
Pull Request resolved: #4433
Add an eviction policy to the embedding config and also enable the config in the mvai model family.
Reviewed By: duduyi2013, yixin94
Differential Revision: D75660955
fbshipit-source-id: e514f56a88b46f5000f8d54478531f7d4e739f21
1 parent 5d24e24 commit 789ea2c

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,71 @@ class EvictionPolicy(NamedTuple):
9696
60
9797
)
9898

99+
def validate(self) -> None:
100+
assert self.eviction_trigger_mode in [0, 1, 2, 3], (
101+
"eviction_trigger_mode must be 0, 1, 2, or 3, "
102+
f"actual {self.eviction_trigger_mode}"
103+
)
104+
if self.eviction_trigger_mode == 0:
105+
return
106+
107+
assert self.eviction_strategy in [0, 1, 2, 3], (
108+
"eviction_strategy must be 0, 1, 2, or 3, "
109+
f"actual {self.eviction_strategy}"
110+
)
111+
if self.eviction_trigger_mode == 1:
112+
assert (
113+
self.eviction_step_intervals is not None
114+
and self.eviction_step_intervals > 0
115+
), (
116+
"eviction_step_intervals must be positive if eviction_trigger_mode is 1, "
117+
f"actual {self.eviction_step_intervals}"
118+
)
119+
elif self.eviction_trigger_mode == 2:
120+
assert (
121+
self.eviction_mem_threshold_gb is not None
122+
), "eviction_mem_threshold_gb must be set if eviction_trigger_mode is 2"
123+
124+
if self.eviction_strategy == 0:
125+
assert self.ttls_in_mins is not None, (
126+
"ttls_in_mins must be set if eviction_strategy is 0, "
127+
f"actual {self.ttls_in_mins}"
128+
)
129+
elif self.eviction_strategy == 1:
130+
assert self.counter_thresholds is not None, (
131+
"counter_thresholds must be set if eviction_strategy is 1, "
132+
f"actual {self.counter_thresholds}"
133+
)
134+
assert self.counter_decay_rates is not None, (
135+
"counter_decay_rates must be set if eviction_strategy is 1, "
136+
f"actual {self.counter_decay_rates}"
137+
)
138+
assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
139+
"counter_thresholds and counter_decay_rates must have the same length, "
140+
f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
141+
)
142+
elif self.eviction_strategy == 2:
143+
assert self.counter_thresholds is not None, (
144+
"counter_thresholds must be set if eviction_strategy is 2, "
145+
f"actual {self.counter_thresholds}"
146+
)
147+
assert self.counter_decay_rates is not None, (
148+
"counter_decay_rates must be set if eviction_strategy is 2, "
149+
f"actual {self.counter_decay_rates}"
150+
)
151+
assert self.ttls_in_mins is not None, (
152+
"ttls_in_mins must be set if eviction_strategy is 2, "
153+
f"actual {self.ttls_in_mins}"
154+
)
155+
assert len(self.counter_thresholds) == len(self.counter_decay_rates), (
156+
"counter_thresholds and counter_decay_rates must have the same length, "
157+
f"actual {self.counter_thresholds} vs {self.counter_decay_rates}"
158+
)
159+
assert len(self.counter_thresholds) == len(self.ttls_in_mins), (
160+
"counter_thresholds and ttls_in_mins must have the same length, "
161+
f"actual {self.counter_thresholds} vs {self.ttls_in_mins}"
162+
)
163+
99164

100165
class KVZCHParams(NamedTuple):
101166
# global bucket id start and global bucket id end offsets for each logical table,
@@ -113,6 +178,8 @@ def validate(self) -> None:
113178
"bucket_offsets and bucket_sizes must have the same length, "
114179
f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
115180
)
181+
if self.eviction_policy is not None:
182+
self.eviction_policy.validate()
116183

117184

118185
class BackendType(enum.IntEnum):

fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -740,6 +740,10 @@ class TimeBasedEvict : public FeatureEvict<weight_type> {
740740
protected:
741741
bool evict_block(weight_type* block, int sub_table_id) override {
742742
int64_t ttl = ttls_in_mins_[sub_table_id];
743+
if (ttl == 0) {
744+
// ttl = 0 means no eviction
745+
return false;
746+
}
743747
auto current_time = FixedBlockPool::current_timestamp();
744748
return current_time - FixedBlockPool::get_timestamp(block) > ttl * 60;
745749
}
@@ -776,8 +780,17 @@ class TimeCounterBasedEvict : public FeatureEvict<weight_type> {
776780
protected:
777781
bool evict_block(weight_type* block, int sub_table_id) override {
778782
int64_t ttl = ttls_in_mins_[sub_table_id];
783+
if (ttl == 0) {
784+
// ttl = 0 means no eviction
785+
return false;
786+
}
779787
double decay_rate = decay_rates_[sub_table_id];
780788
int64_t threshold = thresholds_[sub_table_id];
789+
if (threshold == 0) {
790+
// threshold = 0 means no eviction
791+
return false;
792+
}
793+
781794
// Apply decay and check the count threshold and ttl.
782795
auto current_time = FixedBlockPool::current_timestamp();
783796
auto current_count = FixedBlockPool::get_count(block);
@@ -818,6 +831,10 @@ class L2WeightBasedEvict : public FeatureEvict<weight_type> {
818831
bool evict_block(weight_type* block, int sub_table_id) override {
819832
size_t dimension = sub_table_dims_[sub_table_id];
820833
double threshold = thresholds_[sub_table_id];
834+
if (threshold == 0.0) {
835+
// threshold = 0 means no eviction
836+
return false;
837+
}
821838
auto l2weight = FixedBlockPool::get_l2weight(block, dimension);
822839
return l2weight < threshold;
823840
}

0 commit comments

Comments
 (0)