Skip to content

Commit a6007f4

Browse files
qxy11facebook-github-bot
authored andcommitted
Back out "Move tbe weights as buffers so export can track properly" (#4469)
Summary: Pull Request resolved: #4469 X-link: facebookresearch/FBGEMM#1527 Reviewed By: kqfu Differential Revision: D78106760 fbshipit-source-id: e8a6e467436e8009b453d7e0f5156268d735ace5
1 parent 6bdbc78 commit a6007f4

File tree

1 file changed

+8
-19
lines changed

1 file changed

+8
-19
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -517,29 +517,18 @@ def max_ty_D(ty: SparseType) -> int:
517517
)
518518
self.weight_initialized: bool = False
519519

520-
self.register_buffer(
521-
"weights_dev",
522-
torch.zeros(
523-
0,
524-
device=self.current_device,
525-
dtype=torch.uint8,
526-
),
527-
persistent=False,
520+
self.weights_dev: torch.Tensor = torch.zeros(
521+
0,
522+
device=self.current_device,
523+
dtype=torch.uint8,
528524
)
529525

530-
self.register_buffer(
531-
"weights_host",
532-
torch.zeros(
533-
0,
534-
device=self.current_device,
535-
dtype=torch.uint8,
536-
),
526+
self.weights_host: torch.Tensor = torch.zeros(
527+
0, device=self.current_device, dtype=torch.uint8
537528
)
538529

539-
self.register_buffer(
540-
"weights_uvm",
541-
torch.empty(0, device=self.current_device, dtype=torch.uint8),
542-
persistent=False,
530+
self.weights_uvm: torch.Tensor = torch.empty(
531+
0, device=self.current_device, dtype=torch.uint8
543532
)
544533

545534
cached_dims = [

0 commit comments

Comments
 (0)