diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index a140bee4fa..cabf1b02a4 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -49,6 +49,7 @@
 )
 from fbgemm_gpu.split_table_batched_embeddings_ops_training_common import (
     generate_vbe_metadata,
+    is_torchdynamo_compiling,
 )
 from torch import distributed as dist, nn, Tensor  # usort:skip
 from dataclasses import dataclass
@@ -465,6 +466,15 @@ def __init__(
 
         self.timestep = 0
 
+        # Store the iteration number on GPU and CPU (used for certain optimizers)
+        persistent_iter_ = optimizer in (OptimType.PARTIAL_ROWWISE_ADAM,)
+        self.register_buffer(
+            "iter",
+            torch.zeros(1, dtype=torch.int64, device=self.current_device),
+            persistent=persistent_iter_,
+        )
+        self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")
+
         # Dummy profile configuration for measuring the SSD get/set time
         # get and set are executed by another thread which (for some reason) is
         # not traceable by PyTorch's Kineto. We workaround this problem by
@@ -2059,6 +2069,26 @@ def _generate_vbe_metadata(
             self.current_device,
         )
 
+    def _increment_iteration(self) -> int:
+        # Although self.iter_cpu is created on CPU, it might be transferred
+        # to GPU by the user, so we transfer it back to CPU explicitly. This
+        # needs to happen only once.
+        self.iter_cpu = self.iter_cpu.cpu()
+
+        # Sync with the loaded state; wrapped in a guard to stay compatible
+        # with PT2 compile
+        if not is_torchdynamo_compiling():
+            if self.iter_cpu.item() == 0:
+                self.iter_cpu.fill_(self.iter.cpu().item())
+
+        # Increment the iteration counter
+        # The CPU counterpart is used for local computation
+        iter_int = int(self.iter_cpu.add_(1).item())
+        # The GPU counterpart is used for checkpointing
+        self.iter.add_(1)
+
+        return iter_int
+
     def forward(
         self,
         indices: Tensor,
@@ -2154,6 +2184,9 @@ def forward(
                 self.timesteps_prefetched.pop(0)
 
         self.step += 1
+        # Increment the iteration (value is used for certain optimizers)
+        self._increment_iteration()
+
         if self.optimizer == OptimType.EXACT_SGD:
             raise AssertionError(
                 "SSDTableBatchedEmbeddingBags currently does not support SGD"
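
The core of this change is a dual-counter pattern: a device-side buffer that is checkpointed (only when the optimizer actually needs the step count, e.g. partial row-wise Adam) and a host-side mirror that can be read every step without a device-to-host sync. Below is a minimal, standalone sketch of that pattern; the class name `CounterModule` and its constructor parameters are illustrative, not part of the FBGEMM API.

```python
import torch


class CounterModule(torch.nn.Module):
    """Toy module mirroring the iter / iter_cpu bookkeeping in the diff."""

    def __init__(self, device: torch.device, persistent: bool) -> None:
        super().__init__()
        # Device-side counter: registered as a buffer so it lands in
        # state_dict() (and thus in checkpoints) only when persistent=True.
        self.register_buffer(
            "iter",
            torch.zeros(1, dtype=torch.int64, device=device),
            persistent=persistent,
        )
        # Host-side mirror: a plain tensor attribute kept on CPU so reading
        # it never forces a device-to-host synchronization on the hot path.
        self.iter_cpu: torch.Tensor = torch.zeros(1, dtype=torch.int64, device="cpu")

    def increment(self) -> int:
        # The mirror may have been moved off the host by the user; bring it
        # back. After the first call this is a no-op.
        self.iter_cpu = self.iter_cpu.cpu()
        # After loading a checkpoint, `iter` holds the saved step count while
        # the mirror is still zero; copy the value across exactly once.
        if self.iter_cpu.item() == 0:
            self.iter_cpu.fill_(self.iter.cpu().item())
        # The host-side increment yields a plain int for local computation...
        iter_int = int(self.iter_cpu.add_(1).item())
        # ...while the device-side counter advances for checkpointing.
        self.iter.add_(1)
        return iter_int


m = CounterModule(device=torch.device("cpu"), persistent=True)
assert m.increment() == 1 and m.increment() == 2
assert int(m.state_dict()["iter"].item()) == 2  # the count survives checkpointing
```

The design choice here is to pay one `.item()` on a CPU tensor per step instead of one on a GPU tensor, which would stall the stream; the GPU buffer only matters at save/load boundaries.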
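
The `is_torchdynamo_compiling` guard in the diff exists because `Tensor.item()` introduces data-dependent control flow that TorchDynamo cannot trace through cheaply. A minimal sketch of the same guard using the public PyTorch API follows; the helper name `maybe_sync_counters` is hypothetical, `torch.compiler.is_compiling()` assumes PyTorch >= 2.1, and FBGEMM routes through its own wrapper rather than this call.

```python
import torch


def maybe_sync_counters(iter_cpu: torch.Tensor, iter_gpu: torch.Tensor) -> None:
    # Skip the one-time sync while Dynamo is tracing: the .item() calls
    # below would otherwise force a graph break on data-dependent control
    # flow. In eager mode the sync runs as usual.
    if not torch.compiler.is_compiling():
        if iter_cpu.item() == 0:
            iter_cpu.fill_(iter_gpu.cpu().item())
```

Under `torch.compile`, skipping the sync is safe because the mirror has already been initialized by an earlier eager call, so the compiled graph only ever sees the steady-state increment path.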