From 686a2e984f0ab892689257591b46b67f9fd4bd69 Mon Sep 17 00:00:00 2001
From: Basil Wong
Date: Sun, 6 Jul 2025 15:40:15 -0700
Subject: [PATCH] split_table_batched_embeddings_ops_training int32 support
 behind jk

Summary:

### tl;dr

After this diff stack lands, int32 indices and offsets will be supported for FBGEMM embedding lookup kernels. This can then be enabled via config on APS.

### Implementation

https://docs.google.com/document/d/1GoFghmJcDSGf6XhVkoTJs4C0jTemvpGe1fCNi6oQDRo/edit?usp=sharing

### Context

https://docs.google.com/document/d/1YVfxsafqXkxAAdRyXbjmSH4AEz3-6DBiTGjs1rT8ZHQ/edit?usp=sharing

### Diff specific changes

Puts the ability to cast to int32 behind a jk killswitch, which we can turn off at any time at the kernel level.

Differential Revision: D77843253
---
 ...lit_table_batched_embeddings_ops_training.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index fe8fad0af1..71d9416bcc 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -25,6 +25,7 @@
 
 # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
+import pyjk as justknobs
 
 from fbgemm_gpu.config import FeatureGate, FeatureGateName
 from fbgemm_gpu.runtime_monitor import (
@@ -3600,16 +3601,22 @@ def prepare_inputs(
         # NOTE: Force offsets to have the same dtype as indices since the
         # kernels assume same dtype. We might need to revisit the assumption
         # of same dtypes in the future.
-        if self.embedding_table_index_type == torch.int32:
+        if (
+            self.embedding_table_index_type == torch.int32
+            and self.embedding_table_offset_type == torch.int32
+            and justknobs.check("pytorch/torchrec:int32_rollout_killswitch")
+        ):
             self.log(
-                "Casting indices to int32 based on embedding_table_index_type input."
+                "Casting indices and offsets to int32 based on embedding_table_index_type and embedding_table_offset_type inputs."
             )
             indices = indices.to(torch.int32)
-        if self.embedding_table_index_type != self.embedding_table_offset_type:
+            offsets = offsets.to(torch.int32)
+        else:
             self.log(
-                f"Force casting offsets to {self.embedding_table_index_type} so that it is the same as the indices type."
+                "Casting indices and offsets to int64 because embedding_table_index_type or embedding_table_offset_type is not int32, or the int32 killswitch is off."
             )
-            offsets = offsets.to(dtype=indices.dtype)
+            indices = indices.to(torch.int64)
+            offsets = offsets.to(torch.int64)
 
         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
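
### Reviewer note: standalone sketch of the gating behavior

The sketch below mirrors the gated cast this diff adds to `prepare_inputs`, for readers without access to the fbcode environment. `pyjk`/JustKnobs is Meta-internal, so the knob check is stubbed with a hypothetical `_check_killswitch` helper, and `TBEStub` is an invented stand-in for the real TBE module; only the branch logic matches the diff.

```python
# Minimal sketch of the dtype-gating logic (illustrative only).
# `_check_killswitch` stands in for `justknobs.check(...)` and `TBEStub`
# for the real TBE module; both are assumptions, not FBGEMM APIs.
import torch


def _check_killswitch(name: str) -> bool:
    # Stand-in for `justknobs.check(name)`: returns True while the rollout
    # killswitch is enabled. Flip to False to simulate killing the feature.
    return True


class TBEStub:
    def __init__(self, index_type: torch.dtype, offset_type: torch.dtype) -> None:
        self.embedding_table_index_type = index_type
        self.embedding_table_offset_type = offset_type

    def log(self, msg: str) -> None:
        print(msg)

    def prepare_inputs(
        self, indices: torch.Tensor, offsets: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Cast to int32 only when BOTH configured dtypes are int32 AND the
        # killswitch has not been thrown; otherwise fall back to int64.
        if (
            self.embedding_table_index_type == torch.int32
            and self.embedding_table_offset_type == torch.int32
            and _check_killswitch("pytorch/torchrec:int32_rollout_killswitch")
        ):
            self.log("Casting indices and offsets to int32.")
            return indices.to(torch.int32), offsets.to(torch.int32)
        self.log("Casting indices and offsets to int64.")
        return indices.to(torch.int64), offsets.to(torch.int64)


if __name__ == "__main__":
    tbe = TBEStub(torch.int32, torch.int32)
    indices = torch.tensor([0, 1, 2], dtype=torch.int64)
    offsets = torch.tensor([0, 3], dtype=torch.int64)
    indices, offsets = tbe.prepare_inputs(indices, offsets)
    assert indices.dtype == torch.int32 and offsets.dtype == torch.int32
```

Because the knob is a killswitch, the int64 branch is the safe fallback: disabling the knob reverts every lookup to the previously supported int64 path without a code change.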