Skip to content

Commit 99aea27

Browse files
tugsbayasgalanfacebook-github-bot
authored andcommitted
Move tbe weights as buffers so export can track properly (#4369)
Summary: X-link: facebookresearch/FBGEMM#1438 Pull Request resolved: #4369 Title Reviewed By: SherlockNoMad Differential Revision: D76846172 fbshipit-source-id: 2074df3e5d9d971dba1c9442c8399ae5fc82297e
1 parent dbd59e8 commit 99aea27

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,18 +517,29 @@ def max_ty_D(ty: SparseType) -> int:
517517
)
518518
self.weight_initialized: bool = False
519519

520-
self.weights_dev: torch.Tensor = torch.zeros(
521-
0,
522-
device=self.current_device,
523-
dtype=torch.uint8,
520+
self.register_buffer(
521+
"weights_dev",
522+
torch.zeros(
523+
0,
524+
device=self.current_device,
525+
dtype=torch.uint8,
526+
),
527+
persistent=False,
524528
)
525529

526-
self.weights_host: torch.Tensor = torch.zeros(
527-
0, device=self.current_device, dtype=torch.uint8
530+
self.register_buffer(
531+
"weights_host",
532+
torch.zeros(
533+
0,
534+
device=self.current_device,
535+
dtype=torch.uint8,
536+
),
528537
)
529538

530-
self.weights_uvm: torch.Tensor = torch.empty(
531-
0, device=self.current_device, dtype=torch.uint8
539+
self.register_buffer(
540+
"weights_uvm",
541+
torch.empty(0, device=self.current_device, dtype=torch.uint8),
542+
persistent=False,
532543
)
533544

534545
cached_dims = [

0 commit comments

Comments
 (0)