
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading #4435

Closed

@@ -95,6 +95,7 @@ class EvictionPolicy(NamedTuple):
# wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
60
)
meta_header_lens: Optional[List[int]] = None # metaheader length for each table

def validate(self) -> None:
assert self.eviction_trigger_mode in [0, 1, 2, 3], (
@@ -171,15 +172,20 @@ class KVZCHParams(NamedTuple):
bucket_sizes: List[int] = []
# enable optimizer offloading or not
enable_optimizer_offloading: bool = False
eviction_policy: Optional[EvictionPolicy] = None
# when enabled, the backend returns the whole row (metaheader + weight + optimizer) instead of the weight only
# can only be enabled when enable_optimizer_offloading is enabled
backend_return_whole_row: bool = False
eviction_policy: EvictionPolicy = EvictionPolicy()

def validate(self) -> None:
assert len(self.bucket_offsets) == len(self.bucket_sizes), (
"bucket_offsets and bucket_sizes must have the same length, "
f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
)
if self.eviction_policy is not None:
self.eviction_policy.validate()
self.eviction_policy.validate()
assert (
not self.backend_return_whole_row or self.enable_optimizer_offloading
), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"


class BackendType(enum.IntEnum):
148 changes: 122 additions & 26 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -190,15 +190,27 @@ def __init__(
self.kv_zch_params = kv_zch_params
self.backend_type = backend_type
self.enable_optimizer_offloading: bool = False
self.backend_return_whole_row: bool = False
if self.kv_zch_params:
self.kv_zch_params.validate()
self.enable_optimizer_offloading = (
# pyre-ignore [16]
self.kv_zch_params.enable_optimizer_offloading
)
self.backend_return_whole_row = (
# pyre-ignore [16]
self.kv_zch_params.backend_return_whole_row
)

if self.enable_optimizer_offloading:
logging.info("Optimizer state offloading is enabled")
if self.backend_return_whole_row:
assert (
self.backend_type == BackendType.DRAM
), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
logging.info(
"Backend will return whole row including metaheader, weight and optimizer for checkpoint"
)

self.pooling_mode = pooling_mode
self.bounds_check_mode_int: int = bounds_check_mode.value
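Taken together with KVZCHParams.validate() above, the effective precondition can be restated as a small predicate. This helper is hypothetical (not part of the PR) and assumes the usual Optional import; it only summarizes the checks made here:

def whole_row_supported(
    backend_type: BackendType, kv_zch_params: Optional[KVZCHParams]
) -> bool:
    # Whole-row checkpointing is only wired up for the DRAM backend, and only
    # when the KV ZCH config also enables optimizer offloading.
    return (
        kv_zch_params is not None
        and kv_zch_params.backend_return_whole_row
        and kv_zch_params.enable_optimizer_offloading
        and backend_type == BackendType.DRAM
    )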
@@ -625,13 +637,14 @@ def __init__(
logging.info(
f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
f"max_D={self.max_D}"
f"max_D={self.max_D},"
f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
f"row_storage_bitwidth={weights_precision.bit_rate()},"
f"self.cache_row_dim={self.cache_row_dim},"
f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
f"feature_dims={self.feature_dims},"
f"hash_size_cumsum={self.hash_size_cumsum}"
f"hash_size_cumsum={self.hash_size_cumsum},"
f"backend_return_whole_row={self.backend_return_whole_row}"
)
table_dims = (
tensor_pad4(self.table_dims)
@@ -672,6 +685,7 @@ def __init__(
if self.enable_optimizer_offloading
else None
), # hash_size_cumsum
self.backend_return_whole_row, # backend_return_whole_row
)
else:
raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -2282,16 +2296,19 @@ def split_optimizer_states(
# pyre-ignore
bucket_size = self.kv_zch_params.bucket_sizes[t]
row_offset = table_offset
if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
if not self.backend_return_whole_row and (
sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
):
opt_list.append(
torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
# empty optimizer state for module initialization,
# which will NOT be used for checkpoint loading
)
else:
if not self.enable_optimizer_offloading:
# convert global id back to local id, then linearize with table offset
local_id_tensor = (
sorted_id_tensor[t]
sorted_id_tensor[t] # pyre-ignore[16]
- bucket_id_start * bucket_size
+ table_offset
)
@@ -2300,27 +2317,79 @@
)
else:
row_offset = table_offset - (bucket_id_start * bucket_size)
# using KVTensorWrapper to query backend to avoid OOM memory, since
# backend will return both weight and optimizer in one tensor, read the whole tensor
# out could OOM CPU memory.
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[emb_height, optimizer_dim],
dtype=dtype,
row_offset=row_offset,
snapshot_handle=snapshot_handle,
sorted_indices=sorted_id_tensor[t],
width_offset=pad4(emb_dim),
)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
if self.backend_type == BackendType.SSD
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
)
opt_list.append(
self.get_offloaded_optimizer_states(
tensor_wrapper, sorted_id_tensor[t].numel()
if self.backend_return_whole_row:
# When backend returns whole row, the optimizer will be returned as PMT directly
if (
sorted_id_tensor[t].size(0) == 0
and self.local_weight_counts[t] > 0
):
logging.info(
f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
)
# pyre-ignore [16]
sorted_id_tensor[t] = torch.zeros(
(self.local_weight_counts[t], 1),
device=torch.device("cpu"),
dtype=torch.int64,
)

metaheader_dim = (
# pyre-ignore[16]
self.kv_zch_params.eviction_policy.meta_header_lens[t]
)
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[
(
sorted_id_tensor[t].size(0)
if sorted_id_tensor is not None
and sorted_id_tensor[t].size(0) > 0
else emb_height
),
optimizer_dim,
],
dtype=dtype,
row_offset=row_offset,
snapshot_handle=snapshot_handle,
sorted_indices=sorted_id_tensor[t],
width_offset=(
metaheader_dim # metaheader is already padded so no need for pad4
+ pad4(emb_dim)
),
read_only=True, # optimizer written to DB with weights, so skip write here
)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
if self.backend_type == BackendType.SSD
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
)
opt_list.append(
PartiallyMaterializedTensor(
tensor_wrapper,
True if self.kv_zch_params else False,
)
)
else:
# Use KVTensorWrapper to query the backend to avoid OOM: the backend returns
# both weight and optimizer state in one tensor, and reading the whole tensor
# out at once could exhaust CPU memory.
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[emb_height, optimizer_dim],
dtype=dtype,
row_offset=row_offset,
snapshot_handle=snapshot_handle,
sorted_indices=sorted_id_tensor[t],
width_offset=pad4(emb_dim),
)
(
tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
if self.backend_type == BackendType.SSD
else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
)
opt_list.append(
self.get_offloaded_optimizer_states(
tensor_wrapper, sorted_id_tensor[t].numel()
)
)
)
table_offset += emb_height
logging.info(
f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
@@ -2515,10 +2584,15 @@ def split_embedding_weights(
bucket_ascending_id_tensor = None
bucket_t = None
row_offset = table_offset
metaheader_dim = 0
if self.kv_zch_params:
bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
# pyre-ignore
bucket_size = self.kv_zch_params.bucket_sizes[i]
metaheader_dim = (
# pyre-ignore[16]
self.kv_zch_params.eviction_policy.meta_header_lens[i]
)

# linearize with table offset
table_input_id_start = table_offset
@@ -2548,7 +2622,7 @@ def split_embedding_weights(
and self.local_weight_counts[i] > 0
):
logging.info(
f"resetting bucket id tensor with {self.local_weight_counts[i]}"
f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
)
bucket_ascending_id_tensor = torch.zeros(
(self.local_weight_counts[i], 1),
@@ -2574,7 +2648,19 @@
if bucket_ascending_id_tensor is not None
else emb_height
),
emb_dim,
(
(
metaheader_dim # metaheader is already padded
+ pad4(emb_dim)
+ pad4(
self.optimizer.state_size_dim(
self.weights_precision.as_dtype()
)
)
)
if self.backend_return_whole_row
else emb_dim
),
],
dtype=dtype,
row_offset=row_offset,
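As context for the width expression above: with backend_return_whole_row the weight wrapper is sized to the full row, otherwise to the embedding dim alone. A hedged sketch, with pad4 again assumed to round up to a multiple of 4 and opt_state_dim standing in for self.optimizer.state_size_dim(...):

def full_row_width(
    metaheader_dim: int, emb_dim: int, opt_state_dim: int, whole_row: bool
) -> int:
    pad4 = lambda d: (d + 3) // 4 * 4
    if whole_row:
        # metaheader is already padded; weight and optimizer are padded to 4
        return metaheader_dim + pad4(emb_dim) + pad4(opt_state_dim)
    return emb_dim

# e.g. FP16 weights, emb_dim=127, a 4-byte optimizer state viewed as 2 fp16 elements:
# full_row_width(8, 127, 2, True) == 8 + 128 + 4 == 140 elements per row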
@@ -2611,6 +2697,11 @@

@torch.jit.ignore
def apply_state_dict(self) -> None:
if self.backend_return_whole_row:
logging.info(
"backend_return_whole_row is enabled, no need to apply_state_dict"
)
return
# After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
# Caller should call this function to apply the cached states to backend.
if self.load_state_dict is False:
@@ -2729,6 +2820,11 @@ def streaming_write_weight_and_id_per_table(

@torch.jit.ignore
def enable_load_state_dict_mode(self) -> None:
if self.backend_return_whole_row:
logging.info(
"backend_return_whole_row is enabled, no need to enable load_state_dict mode"
)
return
# Enable load state dict mode before loading checkpoint
if self.load_state_dict:
return
87 changes: 87 additions & 0 deletions fbgemm_gpu/test/tbe/ssd/kv_backend_test.py
@@ -784,3 +784,90 @@ def test_dram_kv_eviction(self) -> None:
self.assertTrue(all(processed_counts >= shard_load))
self.assertTrue(all(full_duration_ms > 0))
self.assertTrue(all(exec_duration_ms >= 0))

@given(
T=st.integers(min_value=2, max_value=10),
D=st.integers(min_value=2, max_value=128),
log_E=st.integers(min_value=2, max_value=3),
weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
enable_l2=st.sampled_from([True, False]),
)
@settings(**default_settings)
def test_dram_enable_backend_return_whole_row(
self,
T: int,
D: int,
log_E: int,
weights_precision: SparseType,
enable_l2: bool,
) -> None:
kv_zch_params = KVZCHParams(
enable_optimizer_offloading=True,
backend_return_whole_row=True, # whole row will be returned to KVT
)
metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)
opt_dim: int = 4 // (weights_precision.bit_rate() // 8)
emb, Es, Ds = self.generate_fbgemm_kv_tbe(
T,
D,
log_E,
weights_precision,
mixed=True,
enable_l2=enable_l2,
kv_zch_params=kv_zch_params,
backend_type=BackendType.DRAM,
)
dtype = weights_precision.as_dtype()
row_offset = 0
max_D = max(Ds)
N = 2

for E, D in zip(Es, Ds):
# create a random permutation of [0, N-1] to use as (unordered) indices
indices = torch.randperm(N)
# insert the weights at the corresponding indices into the table,
# which also populates the metaheader with the weight id at the front of each row
weights = torch.arange(N * D, dtype=dtype).view(N, D)
padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
# emb.ssd_db.set_kv_to_storage(indices + row_offset, padded_weights)
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[E, D], # only write D from weights
dtype=dtype,
row_offset=row_offset,
snapshot_handle=None,
)
tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
tensor_wrapper.set_weights_and_ids(padded_weights, indices)

# reset KVT's shape to full dim to get whole row
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[E, metaheader_dim + pad4(D) + pad4(opt_dim)],
dtype=dtype,
row_offset=row_offset,
snapshot_handle=None,
)
tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)

# Call narrow which should fetch the whole row
narrowed = tensor_wrapper.narrow(0, 0, N)
opt_offset = metaheader_dim + pad4(D)

for i in range(N):
# Check if the id matches
torch.testing.assert_close(
narrowed[i, : metaheader_dim // 2].view(torch.int64),
torch.tensor([i + row_offset], dtype=torch.int64),
)

# Check that the weight matches the one written above
torch.testing.assert_close(
narrowed[i, metaheader_dim:opt_offset],
weights[indices.tolist().index(i)],
)

# The trailing optimizer part should be zero-initialized
torch.testing.assert_close(
narrowed[:, opt_offset : opt_offset + opt_dim],
torch.zeros(N, opt_dim, dtype=dtype),
)
row_offset += E
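A worked example of the dimensions this test derives, assuming a 16-byte metaheader whose first 8 bytes hold the int64 row id and a 4-byte optimizer state (which is what the expressions above imply):

import torch

# FP16 (2 bytes/element): metaheader_dim = 16 // 2 = 8, opt_dim = 4 // 2 = 2
# FP32 (4 bytes/element): metaheader_dim = 16 // 4 = 4, opt_dim = 4 // 4 = 1
metaheader_dim = 8  # FP16 case

# The id check reads the first half of the metaheader (8 bytes) and
# reinterprets it as a single int64:
header = torch.tensor([42], dtype=torch.int64).view(torch.float16)  # 4 fp16 elements
assert header.numel() == metaheader_dim // 2
assert header.view(torch.int64).item() == 42  # round-trips back to the row id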