Skip to content

Commit 9c1fdfd

Browse files
committed
Remove bfloat16 constraint from to_nf4
1 parent 145b1df commit 9c1fdfd

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

torchao/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from . import dtypes
2+
3+
__all__ = [
4+
"dtypes"
5+
]

torchao/dtypes/nf4tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,4 +511,5 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
511511
return LinearNF4.apply(input, weight)
512512

513513
def to_nf4(tensor):
514-
return NF4Tensor.from_tensor(tensor)
514+
tensor1 = tensor.to(torch.bfloat16)
515+
return NF4Tensor.from_tensor(tensor1)

0 commit comments

Comments
 (0)