diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index fe8fad0af1..71d9416bcc 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -25,6 +25,7 @@

 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+import pyjk as justknobs

 from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
@@ -3600,16 +3601,22 @@ def prepare_inputs(
         # NOTE: Force offsets to have the same dtype as indices since the
         # kernels assume same dtype. We might need to revisit the assumption
         # of same dtypes in the future.
-        if self.embedding_table_index_type == torch.int32:
+        if (
+            self.embedding_table_index_type == torch.int32
+            and self.embedding_table_offset_type == torch.int32
+            and justknobs.check("pytorch/torchrec:int32_rollout_killswitch")
+        ):
             self.log(
-                "Casting indices to int32 based on embedding_table_index_type input."
+                "Casting indices and offsets to int32 based on embedding_table_index_type and embedding_table_offset_type inputs."
             )
             indices = indices.to(torch.int32)
-        if self.embedding_table_index_type != self.embedding_table_offset_type:
+            offsets = offsets.to(torch.int32)
+        else:
             self.log(
-                f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
+                "Casting indices and offsets to int64 as either embedding_table_index_type or embedding_table_offset_type is not int32."
             )
-            offsets = offsets.to(dtype=indices.dtype)
+            indices = indices.to(torch.int64)
+            offsets = offsets.to(torch.int64)

         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
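
The patch above changes the casting rule in prepare_inputs: indices and offsets are cast to int32 only when both configured dtypes are int32 and the justknobs killswitch allows it; otherwise both are cast to int64. The following is a minimal standalone sketch of that rule, not part of the patch; the function name and the int32_enabled parameter (standing in for the justknobs.check("pytorch/torchrec:int32_rollout_killswitch") call) are assumptions for illustration only.

    import torch

    def cast_indices_and_offsets(
        indices: torch.Tensor,
        offsets: torch.Tensor,
        index_type: torch.dtype,
        offset_type: torch.dtype,
        int32_enabled: bool,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Cast both tensors to int32 only when both configured dtypes are int32
        # and the killswitch knob is enabled; otherwise force both to int64 so
        # the kernels still see a single, consistent index/offset dtype.
        if index_type == torch.int32 and offset_type == torch.int32 and int32_enabled:
            return indices.to(torch.int32), offsets.to(torch.int32)
        return indices.to(torch.int64), offsets.to(torch.int64)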