Skip to content

Commit 62f3c82

Browse files
duduyi2013facebook-github-bot
authored andcommitted
patch fixes for eviction (#4304)
Summary: Pull Request resolved: #4304 X-link: facebookresearch/FBGEMM#1380 # change set ## Eviction Related 1. move trigger eviction to the beginning of each get call, since get is called once per iteration 2. move resume() to the end of each set calls, there are 2 set calls each train iteration, one happens at the end of prefetch during forward, the other happen when SP embedding is updated during backward 3. change dram kv iteration counter to the get() call which will only bump up once on each train iteration 4. make evict_flag_ auto updated by the last finished shard, to get clearer and deterministic state transition in different cases 5. each eviction round will issue num_shards of long running threads that can be paused/resumed, instead of every pause/resume will destroy/create a new work item for thread pool 6. introduce an additional eviction trigger scheduling logic to avoid repeated and intensive insufficient eviction effort(eviction scan all item but evict nothing), by force a fixed wait interval for 2 consecutive evictions to happen 7. fix get_feature_evict_metric, we need to make a copy before getting out of the mutex scope, otherwise, metrics might be updated async by eviction threads ## Miscellaneous 1. wrap all eviction configs in FeatureEvictConfig and pass it down all the way from TBE to feature_evict, all the future eviction configs will be added inside FeatureEvictConfig 2. make state_dict wait until ongoing eviction finishes 3. fix misuse between feature hash cumsum with table hash cumsum, basically for feature evict, we want table level hash cumsum instead of feature level 4. add UT for different corner case of eviction and make sure state transition is expected Reviewed By: emlin Differential Revision: D76244371 fbshipit-source-id: 96b8e0f0563d5615e56d31d0f91c779be1ba1be5
1 parent 509724d commit 62f3c82

13 files changed

+1372
-540
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import enum
1313
from dataclasses import dataclass
14-
from typing import List, NamedTuple, Tuple
14+
from typing import List, NamedTuple, Optional, Tuple
1515

1616
import torch
1717
from torch import Tensor
@@ -60,6 +60,43 @@ def from_str(cls, key: str):
6060
raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
6161

6262

63+
class EvictionPolicy(NamedTuple):
64+
eviction_trigger_mode: int = (
65+
0 # disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
66+
)
67+
eviction_strategy: int = (
68+
0 # 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
69+
)
70+
eviction_step_intervals: Optional[int] = (
71+
None # trigger_step_interval if trigger mode is iteration
72+
)
73+
eviction_mem_threshold_gb: Optional[int] = (
74+
None # eviction trigger condition if trigger mode is mem_util
75+
)
76+
counter_thresholds: Optional[List[int]] = (
77+
None # count_thresholds for each table if eviction strategy is feature score
78+
)
79+
ttls_in_mins: Optional[List[int]] = (
80+
None # ttls_in_mins for each table if eviction strategy is timestamp
81+
)
82+
counter_decay_rates: Optional[List[float]] = (
83+
None # count_decay_rates for each table if eviction strategy is feature score
84+
)
85+
l2_weight_thresholds: Optional[List[float]] = (
86+
None # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
87+
)
88+
interval_for_insufficient_eviction_s: int = (
89+
# wait at least # seconds before trigger next round of eviction, if last finished eviction is insufficient
90+
# insufficient means we didn't evict enough rows, so we want to wait longer time to
91+
# avoid another insufficient eviction
92+
600
93+
)
94+
interval_for_sufficient_eviction_s: int = (
95+
# wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
96+
60
97+
)
98+
99+
63100
class KVZCHParams(NamedTuple):
64101
# global bucket id start and global bucket id end offsets for each logical table,
65102
# where start offset is inclusive and end offset is exclusive
@@ -69,6 +106,7 @@ class KVZCHParams(NamedTuple):
69106
bucket_sizes: List[int] = []
70107
# enable optimizer offloading or not
71108
enable_optimizer_offloading: bool = False
109+
eviction_policy: Optional[EvictionPolicy] = None
72110

73111
def validate(self) -> None:
74112
assert len(self.bucket_offsets) == len(self.bucket_sizes), (

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,12 @@ def __init__(
248248
self.total_hash_size_bits: int = 0
249249
else:
250250
self.total_hash_size_bits: int = int(log2(float(hash_size_cumsum[-1])) + 1)
251+
self.register_buffer(
252+
"table_hash_size_cumsum",
253+
torch.tensor(
254+
hash_size_cumsum, device=self.current_device, dtype=torch.int64
255+
),
256+
)
251257
# The last element is to easily access # of rows of each table by
252258
self.total_hash_size_bits = int(log2(float(hash_size_cumsum[-1])) + 1)
253259
self.total_hash_size: int = hash_size_cumsum[-1]
@@ -288,6 +294,10 @@ def __init__(
288294
"feature_dims",
289295
torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
290296
)
297+
self.register_buffer(
298+
"table_dims",
299+
torch.tensor(dims, device="cpu", dtype=torch.int64),
300+
)
291301

292302
(info_B_num_bits_, info_B_mask_) = torch.ops.fbgemm.get_infos_metadata(
293303
self.D_offsets, # unused tensor
@@ -518,6 +528,7 @@ def __init__(
518528
logging.warning("dist is not initialized, treating as single gpu cases")
519529
tbe_unique_id = SSDTableBatchedEmbeddingBags._local_instance_index
520530
self.tbe_unique_id = tbe_unique_id
531+
self.l2_cache_size = l2_cache_size
521532
logging.info(f"tbe_unique_id: {tbe_unique_id}")
522533
if self.backend_type == BackendType.SSD:
523534
logging.info(
@@ -564,12 +575,12 @@ def __init__(
564575
self.res_params.table_offsets,
565576
self.res_params.table_sizes,
566577
(
567-
tensor_pad4(self.feature_dims.cpu())
578+
tensor_pad4(self.table_dims)
568579
if self.enable_optimizer_offloading
569580
else None
570581
),
571582
(
572-
self.hash_size_cumsum.cpu()
583+
self.table_hash_size_cumsum.cpu()
573584
if self.enable_optimizer_offloading
574585
else None
575586
),
@@ -609,28 +620,42 @@ def __init__(
609620
f"feature_dims={self.feature_dims},"
610621
f"hash_size_cumsum={self.hash_size_cumsum}"
611622
)
623+
table_dims = (
624+
tensor_pad4(self.table_dims)
625+
if self.enable_optimizer_offloading
626+
else None
627+
) # table_dims
628+
eviction_config = None
629+
if self.kv_zch_params and self.kv_zch_params.eviction_policy:
630+
eviction_mem_threshold_gb = (
631+
self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
632+
if self.kv_zch_params.eviction_policy.eviction_mem_threshold_gb
633+
else self.l2_cache_size
634+
)
635+
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
636+
self.kv_zch_params.eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
637+
self.kv_zch_params.eviction_policy.eviction_strategy, # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
638+
self.kv_zch_params.eviction_policy.eviction_step_intervals, # trigger_step_interval if trigger mode is iteration
639+
eviction_mem_threshold_gb, # mem_util_threshold_in_GB if trigger mode is mem_util
640+
self.kv_zch_params.eviction_policy.ttls_in_mins, # ttls_in_mins for each table if eviction strategy is timestamp
641+
self.kv_zch_params.eviction_policy.counter_thresholds, # counter_thresholds for each table if eviction strategy is feature score
642+
self.kv_zch_params.eviction_policy.counter_decay_rates, # counter_decay_rates for each table if eviction strategy is feature score
643+
self.kv_zch_params.eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
644+
table_dims.tolist() if table_dims is not None else None,
645+
self.kv_zch_params.eviction_policy.interval_for_insufficient_eviction_s,
646+
self.kv_zch_params.eviction_policy.interval_for_sufficient_eviction_s,
647+
)
612648
self._ssd_db = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
613649
self.cache_row_dim,
614650
ssd_uniform_init_lower,
615651
ssd_uniform_init_upper,
616-
0, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual
617-
0, # trigger_step_interval if trigger mode is iteration
618-
0, # mem_util_threshold_in_GB if trigger mode is mem_util
619-
0, # evict_trigger_strategy: 0: timestamp, 1: counter (feature score), 2: counter (feature score) + timestamp, 3: feature l2 norm
620-
None, # count_thresholds for each table if eviction strategy is feature score
621-
None, # ttls_in_mins for each table if eviction strategy is timestamp
622-
None, # count_decay_rates for each table if eviction strategy is feature score
623-
None, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
652+
eviction_config,
624653
ssd_rocksdb_shards, # num_shards
625654
ssd_rocksdb_shards, # num_threads
626655
weights_precision.bit_rate(), # row_storage_bitwidth
656+
table_dims,
627657
(
628-
tensor_pad4(self.feature_dims.cpu())
629-
if self.enable_optimizer_offloading
630-
else None
631-
), # table_dims
632-
(
633-
self.hash_size_cumsum.cpu()
658+
self.table_hash_size_cumsum.cpu()
634659
if self.enable_optimizer_offloading
635660
else None
636661
), # hash_size_cumsum
@@ -2434,6 +2459,13 @@ def _may_create_snapshot_for_state_dict(
24342459
f"created snapshot for weight states: {snapshot_handle}, latency: {(time.time() - start_time) * 1000} ms"
24352460
)
24362461
elif self.backend_type == BackendType.DRAM:
2462+
# if there is any ongoing eviction, lets wait until eviction is finished before state_dict
2463+
# so that we can reach consistent model state before/after state_dict
2464+
evict_wait_start_time = time.time()
2465+
self.ssd_db.wait_until_eviction_done()
2466+
logging.info(
2467+
f"state_dict wait for ongoing eviction: {time.time() - evict_wait_start_time} s"
2468+
)
24372469
self.flush(force=should_flush)
24382470
return snapshot_handle, checkpoint_handle
24392471

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_base.h

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments

Comments
 (0)