Skip to content

Commit acc9889

Browse files
set dtype
1 parent ac14d92 commit acc9889

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/prototype/moe_training/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
tensor: torch.Tensor,
6565
dtype: torch.dtype,
6666
):
67-
self._data = tensor
67+
self._data = tensor.to(dtype)
6868
self._dtype = dtype
6969

7070
@classmethod

0 commit comments

Comments
 (0)