Skip to content

Commit 8fb51c1

Browse files
houseroadfacebook-github-bot
authored andcommitted
Make initialize_weights scriptable in IntNBitTableBatchedEmbeddingBagsCodegen (#1622)
Summary: Pull Request resolved: #1622 Need to init the weight during model loading with XL V2, so export it. Reviewed By: qxy11, jianyuh Differential Revision: D43709984 fbshipit-source-id: 1730ea6f108303fee0811eb9de3fa5bef2fa8267
1 parent 4aaa389 commit 8fb51c1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,7 +2451,8 @@ def _apply_split(
24512451
if uvm_size > 0:
24522452
assert not self.use_cpu
24532453
if enforce_hbm:
2454-
logging.info("Enforce hbm for the cache location")
2454+
if not torch.jit.is_scripting():
2455+
logging.info("Enforce hbm for the cache location")
24552456
self.weights_uvm = torch.zeros(
24562457
uvm_size,
24572458
device=self.current_device,
@@ -2767,6 +2768,7 @@ def split_embedding_weights(
27672768

27682769
return splits
27692770

2771+
@torch.jit.export
27702772
def initialize_weights(self) -> None:
27712773
if not self.weight_initialized:
27722774
self._apply_split(
@@ -2777,7 +2779,7 @@ def initialize_weights(self) -> None:
27772779
self.weights_physical_offsets,
27782780
self.enforce_hbm,
27792781
)
2780-
self.weight_initialized: bool = True
2782+
self.weight_initialized = True
27812783

27822784
def fill_random_weights(self) -> None:
27832785
"""

0 commit comments

Comments
 (0)