diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
index 1bd696ed8d..714637cabe 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
@@ -95,6 +95,7 @@ class EvictionPolicy(NamedTuple):
         # wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
         60
     )
+    meta_header_lens: Optional[List[int]] = None  # metaheader length for each table

     def validate(self) -> None:
         assert self.eviction_trigger_mode in [0, 1, 2, 3], (
@@ -171,15 +172,20 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = False
-    eviction_policy: Optional[EvictionPolicy] = None
+    # when enabled, backend will return whole row (metaheader + weight + optimizer) instead of weight only
+    # can only be enabled when enable_optimizer_offloading is enabled
+    backend_return_whole_row: bool = False
+    eviction_policy: EvictionPolicy = EvictionPolicy()

     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
             "bucket_offsets and bucket_sizes must have the same length, "
             f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
         )
-        if self.eviction_policy is not None:
-            self.eviction_policy.validate()
+        self.eviction_policy.validate()
+        assert (
+            not self.backend_return_whole_row or self.enable_optimizer_offloading
+        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"


 class BackendType(enum.IntEnum):
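The two fields added above travel together: `meta_header_lens` records the per-table metaheader width in elements, and `backend_return_whole_row` is only legal when optimizer offloading is enabled, which `validate()` now enforces. A minimal sketch of how a caller might fill these in, assuming FP16 tables and the 16-byte metaheader used by the tests later in this patch (the bucket values are purely illustrative, not taken from the source):

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        EvictionPolicy,
        KVZCHParams,
    )

    num_tables = 2
    element_bytes = 2  # hypothetical FP16 tables -> 2 bytes per element

    params = KVZCHParams(
        bucket_offsets=[(0, 5)] * num_tables,  # illustrative bucket ranges
        bucket_sizes=[2**20] * num_tables,     # illustrative bucket size
        enable_optimizer_offloading=True,      # prerequisite checked by validate()
        backend_return_whole_row=True,         # rows come back as metaheader + weight + optimizer
        eviction_policy=EvictionPolicy(
            meta_header_lens=[16 // element_bytes] * num_tables,
        ),
    )
    params.validate()  # asserts backend_return_whole_row implies enable_optimizer_offloading

Making `eviction_policy` non-optional (defaulting to `EvictionPolicy()`) is what lets `validate()` delegate unconditionally to `EvictionPolicy.validate()` without the previous None check.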
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index cabf1b02a4..7ef539e036 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -190,15 +190,27 @@ def __init__(
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
             )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )

             if self.enable_optimizer_offloading:
                 logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
@@ -625,13 +637,14 @@ def __init__(
             logging.info(
                 f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
                 f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
-                f"max_D={self.max_D}"
+                f"max_D={self.max_D},"
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
                 f"row_storage_bitwidth={weights_precision.bit_rate()},"
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
             table_dims = (
                 tensor_pad4(self.table_dims)
@@ -672,6 +685,7 @@ def __init__(
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
+                self.backend_return_whole_row,  # backend_return_whole_row
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -2282,16 +2296,19 @@ def split_optimizer_states(
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[t]
                 row_offset = table_offset
-                if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                if not self.backend_return_whole_row and (
+                    sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
+                ):
                     opt_list.append(
                         torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
                         # empty optimizer state for module initialization
+                        # which will NOT be used for cp loading
                     )
                 else:
                     if not self.enable_optimizer_offloading:
                         # convert global id back to local id, then linearize with table offset
                         local_id_tensor = (
-                            sorted_id_tensor[t]
+                            sorted_id_tensor[t]  # pyre-ignore[16]
                             - bucket_id_start * bucket_size
                             + table_offset
                         )
@@ -2300,27 +2317,79 @@ def split_optimizer_states(
                         )
                     else:
                         row_offset = table_offset - (bucket_id_start * bucket_size)
-                        # using KVTensorWrapper to query backend to avoid OOM memory, since
-                        # backend will return both weight and optimizer in one tensor, read the whole tensor
-                        # out could OOM CPU memory.
-                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                            shape=[emb_height, optimizer_dim],
-                            dtype=dtype,
-                            row_offset=row_offset,
-                            snapshot_handle=snapshot_handle,
-                            sorted_indices=sorted_id_tensor[t],
-                            width_offset=pad4(emb_dim),
-                        )
-                        (
-                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                            if self.backend_type == BackendType.SSD
-                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                        )
-                        opt_list.append(
-                            self.get_offloaded_optimizer_states(
-                                tensor_wrapper, sorted_id_tensor[t].numel()
+                        if self.backend_return_whole_row:
+                            # When backend returns whole row, the optimizer will be returned as PMT directly
+                            if (
+                                sorted_id_tensor[t].size(0) == 0
+                                and self.local_weight_counts[t] > 0
+                            ):
+                                logging.info(
+                                    f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                                )
+                                # pyre-ignore [16]
+                                sorted_id_tensor[t] = torch.zeros(
+                                    (self.local_weight_counts[t], 1),
+                                    device=torch.device("cpu"),
+                                    dtype=torch.int64,
+                                )
+
+                            metaheader_dim = (
+                                # pyre-ignore[16]
+                                self.kv_zch_params.eviction_policy.meta_header_lens[t]
+                            )
+                            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                                shape=[
+                                    (
+                                        sorted_id_tensor[t].size(0)
+                                        if sorted_id_tensor is not None
+                                        and sorted_id_tensor[t].size(0) > 0
+                                        else emb_height
+                                    ),
+                                    optimizer_dim,
+                                ],
+                                dtype=dtype,
+                                row_offset=row_offset,
+                                snapshot_handle=snapshot_handle,
+                                sorted_indices=sorted_id_tensor[t],
+                                width_offset=(
+                                    metaheader_dim  # metaheader is already padded so no need for pad4
+                                    + pad4(emb_dim)
+                                ),
+                                read_only=True,  # optimizer written to DB with weights, so skip write here
+                            )
+                            (
+                                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                                if self.backend_type == BackendType.SSD
+                                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                            )
+                            opt_list.append(
+                                PartiallyMaterializedTensor(
+                                    tensor_wrapper,
+                                    True if self.kv_zch_params else False,
+                                )
+                            )
+                        else:
+                            # using KVTensorWrapper to query backend to avoid OOM memory, since
+                            # backend will return both weight and optimizer in one tensor, read the whole tensor
+                            # out could OOM CPU memory.
+                            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                                shape=[emb_height, optimizer_dim],
+                                dtype=dtype,
+                                row_offset=row_offset,
+                                snapshot_handle=snapshot_handle,
+                                sorted_indices=sorted_id_tensor[t],
+                                width_offset=pad4(emb_dim),
+                            )
+                            (
+                                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                                if self.backend_type == BackendType.SSD
+                                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                            )
+                            opt_list.append(
+                                self.get_offloaded_optimizer_states(
+                                    tensor_wrapper, sorted_id_tensor[t].numel()
+                                )
                             )
-                        )
             table_offset += emb_height
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
@@ -2515,10 +2584,15 @@ def split_embedding_weights(
             bucket_ascending_id_tensor = None
             bucket_t = None
             row_offset = table_offset
+            metaheader_dim = 0
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]
+                metaheader_dim = (
+                    # pyre-ignore[16]
+                    self.kv_zch_params.eviction_policy.meta_header_lens[i]
+                )

                 # linearize with table offset
                 table_input_id_start = table_offset
@@ -2548,7 +2622,7 @@ def split_embedding_weights(
                     and self.local_weight_counts[i] > 0
                 ):
                     logging.info(
-                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                        f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                     )
                     bucket_ascending_id_tensor = torch.zeros(
                         (self.local_weight_counts[i], 1),
@@ -2574,7 +2648,19 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    emb_dim,
+                    (
+                        (
+                            metaheader_dim  # metaheader is already padded
+                            + pad4(emb_dim)
+                            + pad4(
+                                self.optimizer.state_size_dim(
+                                    self.weights_precision.as_dtype()
+                                )
+                            )
+                        )
+                        if self.backend_return_whole_row
+                        else emb_dim
+                    ),
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
@@ -2611,6 +2697,11 @@ def split_embedding_weights(

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to apply_state_dict"
+            )
+            return
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
         if self.load_state_dict is False:
@@ -2729,6 +2820,11 @@ def streaming_write_weight_and_id_per_table(

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
+            )
+            return
         # Enable load state dict mode before loading checkpoint
         if self.load_state_dict:
             return
diff --git a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py
index c9a871e6d8..307629ee6d 100644
--- a/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/kv_backend_test.py
@@ -784,3 +784,90 @@ def test_dram_kv_eviction(self) -> None:
         self.assertTrue(all(processed_counts >= shard_load))
         self.assertTrue(all(full_duration_ms > 0))
         self.assertTrue(all(exec_duration_ms >= 0))
+
+    @given(
+        T=st.integers(min_value=2, max_value=10),
+        D=st.integers(min_value=2, max_value=128),
+        log_E=st.integers(min_value=2, max_value=3),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        enable_l2=st.sampled_from([True, False]),
+    )
+    @settings(**default_settings)
+    def test_dram_enable_backend_return_whole_row(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        weights_precision: SparseType,
+        enable_l2: bool,
+    ) -> None:
+        kv_zch_params = KVZCHParams(
+            enable_optimizer_offloading=True,
+            backend_return_whole_row=True,  # whole row will be returned to KVT
+        )
+        metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)
+        opt_dim: int = 4 // (weights_precision.bit_rate() // 8)
+        emb, Es, Ds = self.generate_fbgemm_kv_tbe(
+            T,
+            D,
+            log_E,
+            weights_precision,
+            mixed=True,
+            enable_l2=enable_l2,
+            kv_zch_params=kv_zch_params,
+            backend_type=BackendType.DRAM,
+        )
+        dtype = weights_precision.as_dtype()
+        row_offset = 0
+        max_D = max(Ds)
+        N = 2
+
+        for E, D in zip(Es, Ds):
+            # create random index tensor with size N, valued from [0, N-1] unordered
+            indices = torch.randperm(N)
+            # insert the weights with the corresponding indices into the table
+            # which will also populate the metaheader with weight_id at front
+            weights = torch.arange(N * D, dtype=dtype).view(N, D)
+            padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
+            # emb.ssd_db.set_kv_to_storage(indices + row_offset, padded_weights)
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, D],  # only write D from weights
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+            tensor_wrapper.set_weights_and_ids(padded_weights, indices)
+
+            # reset KVT's shape to full dim to get whole row
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, metaheader_dim + pad4(D) + pad4(opt_dim)],
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+
+            # Call narrow which should fetch the whole row
+            narrowed = tensor_wrapper.narrow(0, 0, N)
+            opt_offset = metaheader_dim + pad4(D)
+
+            for i in range(N):
+                # Check if the id matches
+                torch.testing.assert_close(
+                    narrowed[i, : metaheader_dim // 2].view(torch.int64),
+                    torch.tensor([i + row_offset], dtype=torch.int64),
+                )
+
+                # Check if weight matches the one passed in with weights
+                torch.testing.assert_close(
+                    narrowed[i, metaheader_dim:opt_offset],
+                    weights[indices.tolist().index(i)],
+                )
+
+            # The trailing opt part should all be init'ed with 0s
+            torch.testing.assert_close(
+                narrowed[:, opt_offset : opt_offset + opt_dim],
+                torch.zeros(N, opt_dim, dtype=dtype),
+            )
+            row_offset += E
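The assertions in the test above rely on the whole-row layout the DRAM backend returns when backend_return_whole_row is set: metaheader first, then the weight padded to a multiple of 4 elements, then the optimizer state, also padded. A small sketch of that arithmetic, assuming the 16-byte metaheader and 4-byte optimizer state the test uses; `pad4` here is a local stand-in for the helper the tests import:

    def pad4(n: int) -> int:
        # Stand-in for the pad4 helper used in the tests: round an element
        # count up to the next multiple of 4.
        return (n + 3) // 4 * 4

    def whole_row_slices(D: int, element_bytes: int):
        # Layout assumed from the test above: 16-byte metaheader, padded weight,
        # then a 4-byte optimizer state; padding fills the rest of the row.
        metaheader_dim = 16 // element_bytes
        opt_dim = 4 // element_bytes
        weight_slice = slice(metaheader_dim, metaheader_dim + D)
        opt_offset = metaheader_dim + pad4(D)
        opt_slice = slice(opt_offset, opt_offset + opt_dim)
        total_dim = metaheader_dim + pad4(D) + pad4(opt_dim)
        return metaheader_dim, weight_slice, opt_slice, total_dim

    # FP16 example (2 bytes per element) with a hypothetical embedding dim of 62:
    metaheader_dim, weight_slice, opt_slice, total_dim = whole_row_slices(62, 2)
    assert (metaheader_dim, total_dim) == (8, 8 + 64 + 4)
    # The first metaheader_dim // 2 elements hold the row id when reinterpreted as
    # int64, which is why the test calls .view(torch.int64) on that prefix.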
diff --git a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py
index ef18965225..adb3102bfb 100644
--- a/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py
@@ -647,3 +647,65 @@ def test_dram_kv_and_rdb_snapshot_check(self) -> None:
             tensor_wrapper.narrow(0, 0, 1)
         with self.assertRaises(RuntimeError):
             tensor_wrapper.get_weights_by_ids(torch.tensor([1]))
+
+    def test_dram_kv_read_only_mode(self) -> None:
+        max_D = MAX_D  # max emb dimension seen by dram backend
+        D = 64
+        N = 10  # window size
+        E = int(1e4)
+        weights_precision = SparseType.FP16
+        weights_dtype = weights_precision.as_dtype()
+
+        dram_kv = torch.classes.fbgemm.DramKVEmbeddingCacheWrapper(
+            max_D=max_D,
+            uniform_init_lower=-0.1,
+            uniform_init_upper=0.1,
+            num_shards=8,
+            num_threads=32,
+            row_storage_bitwidth=weights_precision.bit_rate(),
+        )
+
+        # create random index tensor with size N
+        indices = torch.arange(N)
+        # insert the weights with the corresponding indices into the table
+        weights = torch.arange(N * D, dtype=weights_dtype).view(N, D)
+        padded_weights = torch.nn.functional.pad(weights, (0, max_D - D), value=1.0)
+        count = torch.tensor([N])
+        dram_kv.set(indices, padded_weights, count)
+
+        tensor_wrapper_read_only = torch.classes.fbgemm.KVTensorWrapper(
+            shape=[E, D], dtype=weights_dtype, row_offset=0, read_only=True
+        )
+        tensor_wrapper_read_only.set_dram_db_wrapper(dram_kv)
+
+        # Get the weights that are already stored in the DRAM KV cache
+        narrowed_weights = tensor_wrapper_read_only.narrow(0, 0, N)
+        weights_by_ids = tensor_wrapper_read_only.get_weights_by_ids(indices)
+        self.assertTrue(
+            torch.equal(narrowed_weights, weights),
+            msg=(
+                f"Tensor value mismatch :\n"
+                f"actual\n{narrowed_weights}\n\nexpected\n{weights}"
+            ),
+        )
+        self.assertTrue(
+            torch.equal(weights_by_ids, weights),
+            msg=(
+                f"Tensor value mismatch :\n"
+                f"actual\n{weights_by_ids}\n\nexpected\n{weights}"
+            ),
+        )
+
+        # Try to set_range() on a read-only tensor wrapper, which should be a no-op
+        insert_weight = torch.randn(D, dtype=weights_dtype).view(1, D)
+        tensor_wrapper_read_only.set_range(0, N, 1, insert_weight)
+
+        # narrow from the above, which should not match the original weights
+        narrowed_weight = tensor_wrapper_read_only.narrow(0, N, 1)
+        self.assertTrue(
+            not torch.equal(narrowed_weight, insert_weight),
+            msg=(
+                f"Tensor value should not match :\n"
+                f"actual\n{narrowed_weight}\n\nexpected\n{insert_weight}"
+            ),
+        )
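The training-path test below checks the id stored in each row's metaheader against the table-local (bucket-ascending) id plus a per-table offset, stepping the offset by VIRTUAL_TABLE_ROWS for each physical table. A sketch of that bookkeeping, under the assumption made by that test that every virtual table spans VIRTUAL_TABLE_ROWS rows of the linearized id space (the helper itself is illustrative, not part of the test utilities):

    VIRTUAL_TABLE_ROWS = 2**18  # matches the constant introduced in the test file

    def expected_metaheader_ids(bucket_asc_ids_per_table):
        # bucket_asc_ids_per_table: per-table lists of table-local ids, as returned
        # alongside split_embedding_weights(); returns the linearized ids the
        # metaheader is expected to contain.
        expected = []
        table_offset = 0
        for local_ids in bucket_asc_ids_per_table:
            expected.append([local_id + table_offset for local_id in local_ids])
            table_offset += VIRTUAL_TABLE_ROWS
        return expected

    # e.g. expected_metaheader_ids([[0, 3], [1]]) -> [[0, 3], [262145]]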
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index f483750a86..881900d321 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -21,6 +21,7 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BackendType,
     BoundsCheckMode,
+    EvictionPolicy,
     KVZCHParams,
     PoolingMode,
 )
@@ -41,6 +42,9 @@
 MAX_EXAMPLES = 40
 MAX_PIPELINE_EXAMPLES = 10
 KV_WORLD_SIZE = 4
+VIRTUAL_TABLE_ROWS = int(
+    2**18
+)  # relatively large for now given optimizer is still pre-allocated

 default_st: Dict["str", Any] = {
     "T": st.integers(min_value=1, max_value=10),
@@ -284,6 +288,7 @@ def generate_kvzch_tbes(
         num_buckets: int = 10,
         mixed: bool = False,
         enable_optimizer_offloading: bool = False,
+        backend_return_whole_row: bool = False,
    ) -> Tuple[
         SSDTableBatchedEmbeddingBags,
         List[torch.nn.EmbeddingBag],
@@ -313,9 +318,7 @@ def generate_kvzch_tbes(
         torch.manual_seed(42)
         E = int(10**log_E)
-        virtual_E = int(
-            2**18
-        )  # relatively large for now given optimizer is still pre-allocated
+        virtual_E = VIRTUAL_TABLE_ROWS
         D = D * 4

         bucket_sizes = []
@@ -327,11 +330,19 @@ def generate_kvzch_tbes(
             )
             bucket_end = min(math.ceil(num_buckets / KV_WORLD_SIZE), num_buckets)
             bucket_offsets.append((bucket_start, bucket_end))
+
+        # In reality this will be populated with _populate_zero_collision_tbe_params
+        # from virtual_table_eviction_policy. For UT, we need to explicitly populate it
         kv_zch_param = KVZCHParams(
             bucket_offsets=bucket_offsets,
             bucket_sizes=bucket_sizes,
             enable_optimizer_offloading=enable_optimizer_offloading,
+            backend_return_whole_row=backend_return_whole_row,
+            eviction_policy=EvictionPolicy(
+                meta_header_lens=([16 // (weights_precision.bit_rate() // 8)] * T)
+            ),
         )
+
         E = min(E, (bucket_offsets[0][1] - bucket_offsets[0][0]) * bucket_sizes[0])

         if not mixed:
@@ -2048,6 +2059,276 @@ def test_kv_opt_state_w_offloading(
                 rtol=tolerance,
             )
+
+    @given(
+        **default_st,
+        num_buckets=st.integers(min_value=10, max_value=15),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_kv_state_dict_w_backend_return_whole_row(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        cache_set_scale: float,
+        pooling_mode: PoolingMode,
+        weights_precision: SparseType,
+        output_dtype: SparseType,
+        share_table: bool,
+        trigger_bounds_check: bool,
+        mixed_B: bool,
+        num_buckets: int,
+    ) -> None:
+        # Constants
+        lr = 0.5
+        eps = 0.2
+        ssd_shards = 2
+        metaheader_dim = 16 // (weights_precision.bit_rate() // 8)  # 8-byte metaheader
+        opt_dim = 4 // (weights_precision.bit_rate() // 8)  # 4-byte optimizer state
+
+        trigger_bounds_check = False  # don't simulate boundary check cases
+        assume(not weighted or pooling_mode == PoolingMode.SUM)
+        assume(not mixed_B or pooling_mode != PoolingMode.NONE)
+
+        # Generate embedding modules and inputs
+        (
+            emb,
+            emb_ref,
+            Es,
+            _,
+            bucket_offsets,
+            bucket_sizes,
+        ) = self.generate_kvzch_tbes(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            lr=lr,
+            eps=eps,
+            ssd_shards=ssd_shards,
+            cache_set_scale=cache_set_scale,
+            pooling_mode=pooling_mode,
+            weights_precision=weights_precision,
+            output_dtype=output_dtype,
+            share_table=share_table,
+            num_buckets=num_buckets,
+            backend_type=BackendType.DRAM,
+            enable_optimizer_offloading=True,
+            backend_return_whole_row=True,
+        )
+
+        # Generate inputs
+        (
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            batch_size_per_feature_per_rank,
+        ) = self.generate_inputs_(
+            B,
+            L,
+            Es,
+            emb.feature_table_map,
+            weights_precision=weights_precision,
+            trigger_bounds_check=trigger_bounds_check,
+            mixed_B=mixed_B,
+            bucket_offsets=bucket_offsets,
+            bucket_sizes=bucket_sizes,
+            is_kv_tbes=True,
+        )
+
+        # Execute forward
+        output_ref_list, output = self.execute_ssd_forward_(
+            emb,
+            emb_ref,
+            indices_list,
+            per_sample_weights_list,
+            indices,
+            offsets,
+            per_sample_weights,
+            B,
+            L,
+            weighted,
+            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+        )
+
+        # Generate output gradient
+        output_grad_list = [torch.randn_like(out) for out in output_ref_list]
+
+        # Execute torch EmbeddingBag backward
+        [out.backward(grad) for (out, grad) in zip(output_ref_list, output_grad_list)]
+        if batch_size_per_feature_per_rank is not None:
+            grad_test = self.concat_ref_tensors_vbe(
+                output_grad_list, batch_size_per_feature_per_rank
+            )
+        else:
+            grad_test = self.concat_ref_tensors(
+                output_grad_list,
+                pooling_mode != PoolingMode.NONE,  # do_pooling
+                B,
+                D * 4,
+            )
+
+        # Execute TBE SSD backward
+        output.backward(grad_test)
+
+        tolerance = (
+            1.0e-4
+            if weights_precision == SparseType.FP32 and output_dtype == SparseType.FP32
+            else 1.0e-2
+        )
+
+        emb.flush()
+
+        # Compare emb state dict with expected values from nn.EmbeddingBag
+        emb_state_dict_list, bucket_asc_ids_list, num_active_id_per_bucket_list = (
+            emb.split_embedding_weights(no_snapshot=False, should_flush=True)
+        )
+        split_optimizer_states = emb.split_optimizer_states(
+            bucket_asc_ids_list, no_snapshot=False
+        )
+        table_input_id_range = []
+        for t, row in enumerate(Es):
+            bucket_id_start = bucket_offsets[t][0]
+            bucket_id_end = bucket_offsets[t][1]
+            bucket_size = bucket_sizes[t]
+            table_input_id_range.append(
+                (
+                    min(bucket_id_start * bucket_size, row),
+                    min(bucket_id_end * bucket_size, row),
+                )
+            )
+            # since we use ref_emb in dense format, the rows start from id 0
+            self.assertEqual(table_input_id_range[-1][0], 0)
+
+        """
+        validate optimizer states
+        """
+        opt_validated = []
+        for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
+            # pyre-fixme[16]: Optional type has no attribute `float`.
+            ref_emb = emb_ref[f].weight.grad.float().to_dense().pow(2).cpu()
+            ref_optimizer_state = ref_emb.mean(dim=1)[
+                table_input_id_range[t][0] : min(
+                    table_input_id_range[t][1], emb_ref[f].weight.size(0)
+                )
+            ]
+            # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `__getitem__`.
+            ref_kv_opt = ref_optimizer_state[bucket_asc_ids_list[t]].view(-1)
+            opt = (
+                split_optimizer_states[t]
+                .narrow(0, 0, bucket_asc_ids_list[t].size(0))
+                .view(-1)
+                .view(torch.float32)
+                .float()
+            )
+            opt_validated.append(opt.clone().detach())
+            torch.testing.assert_close(
+                opt,
+                ref_kv_opt,
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
+        table_offset = 0
+        for feature_index, table_index in self.get_physical_table_arg_indices_(
+            emb.feature_table_map
+        ):
+            """
+            validate bucket_asc_ids_list and num_active_id_per_bucket_list
+            """
+            bucket_asc_id = bucket_asc_ids_list[table_index]
+            num_active_id_per_bucket = num_active_id_per_bucket_list[table_index]
+
+            bucket_id_start = bucket_offsets[table_index][0]
+            bucket_id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(
+                num_active_id_per_bucket.view(-1)
+            )
+            for bucket_idx, id_count in enumerate(num_active_id_per_bucket):
+                bucket_id = bucket_idx + bucket_id_start
+                active_id_cnt = 0
+                for idx in range(
+                    bucket_id_offsets[bucket_idx],
+                    bucket_id_offsets[bucket_idx + 1],
+                ):
+                    # for chunk-based hashing
+                    self.assertEqual(
+                        bucket_id, bucket_asc_id[idx] // bucket_sizes[table_index]
+                    )
+                    active_id_cnt += 1
+                self.assertEqual(active_id_cnt, id_count)
+
+            """
+            validate the whole embeddings rows (metaheader + weight + opt)
+            """
+            num_ids = len(bucket_asc_ids_list[table_index])
+            emb_r_w = emb_ref[feature_index].weight[
+                bucket_asc_ids_list[table_index].view(-1)
+            ]
+            emb_r_w_g = (
+                emb_ref[feature_index]
+                .weight.grad.float()
+                .to_dense()[bucket_asc_ids_list[table_index].view(-1)]
+            )
+            self.assertLess(table_index, len(emb_state_dict_list))
+            assert split_optimizer_states[table_index].size(0) == num_ids
+            new_ref_weight = torch.addcdiv(
+                emb_r_w.float(),
+                value=-lr,
+                tensor1=emb_r_w_g,
+                tensor2=opt_validated[table_index]
+                .clone()
+                .sqrt_()
+                .add_(eps)
+                .view(
+                    num_ids,
+                    1,
+                )
+                .cuda(),
+            ).cpu()
+
+            emb_w = emb_state_dict_list[table_index].narrow(
+                0, 0, bucket_asc_ids_list[table_index].size(0)
+            )
+            # Compare the opt part
+            opt_extracted_from_emb_w = (
+                emb_w[:, (metaheader_dim + D * 4) : (metaheader_dim + D * 4) + opt_dim]
+                .view(torch.float32)
+                .view(-1)
+            )
+            torch.testing.assert_close(
+                opt_extracted_from_emb_w,
+                opt_validated[table_index],
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
+            # Compare the id part
+            id_extracted_from_emb_w = (
+                emb_w[:, 0 : metaheader_dim // 2].view(torch.int64).view(-1)
+            )
+            torch.testing.assert_close(
+                id_extracted_from_emb_w,
+                bucket_asc_ids_list[table_index].view(-1) + table_offset,
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
+            # Compare the weight part
+            torch.testing.assert_close(
+                emb_w[:, metaheader_dim : metaheader_dim + D * 4].float(),
+                new_ref_weight,
+                atol=tolerance,
+                rtol=tolerance,
+            )
+
+            table_offset += VIRTUAL_TABLE_ROWS
+
     @given(
         **default_st,
         num_buckets=st.integers(min_value=10, max_value=15),