Skip to content

Commit 9f3c6ae

Browse files
committed
to_nf4 and support for _to_copy
1 parent 55e5d40 commit 9f3c6ae

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

torchao/dtypes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from .nf4tensor import NF4Tensor, to_nf4
12
from .uint4 import UInt4Tensor
23

34
__all__ = [
5+
"NF4Tensor",
6+
"to_nf4",
47
"UInt4Tensor"
58
]

torchao/dtypes/nf4tensor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def decorator(func):
2828
def noop_detach(func, *args, **kwargs):
2929
return args[0][0]
3030

31+
@implements([torch.ops.aten._to_copy.default])
32+
def _to_copy(func, *args, **kwargs):
33+
return args[0][0].get_original_weight().to(args[1]['dtype'])
34+
3135

3236
@dataclass
3337
class SubclassTensorArgs:
@@ -416,6 +420,10 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
416420
inner_tensors["nf4"],
417421
)
418422

423+
424+
def __str__(self):
425+
return self.to(torch.float32).__str__()
426+
419427
@classmethod
420428
def __torch_dispatch__(cls, func, types, args, kwargs=None):
421429
"""TODO we are not supporting torch dispatch at the moment
@@ -465,3 +473,6 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
465473
weight: NF4Tensor weight
466474
"""
467475
return LinearNF4.apply(input, weight)
476+
477+
def to_nf4(tensor):
478+
return NF4Tensor.from_tensor(tensor)

0 commit comments

Comments
 (0)