
Commit 0babda5

Merge pull request #43 from pytorch-labs/nf4to1
to_nf4 and support for _to_copy
2 parents ebde5e6 + 40a95c5 commit 0babda5

File tree: 5 files changed, +46 -17 lines


test/dtypes/test_uint4.py

Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
     compute_error,
 )
 from torchao.quantization.quant_api import (
-    replace_with_custom_fn_if_matches_filter,
+    _replace_with_custom_fn_if_matches_filter,
 )
 from torch.ao.quantization.observer import ObserverBase
 from torch import nn
@@ -36,7 +36,7 @@ def fn(mod):
         mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False)
         return mod

-    replace_with_custom_fn_if_matches_filter(
+    _replace_with_custom_fn_if_matches_filter(
         model,
         lambda mod: fn(mod),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),

test/modules/test_nf4_linear.py

Lines changed: 20 additions & 13 deletions

@@ -4,7 +4,7 @@
 import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase
-from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
+from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
 import torch.nn.functional as F
 import io
 from collections import OrderedDict
@@ -48,7 +48,7 @@ class TestNF4Linear(TestCase):
     class TestMod(nn.Module):
         def __init__(self, tensor, block_size, scaler_block_size):
             super().__init__()
-            self.param = torch.nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))
+            self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size))

     def save_state_dict_to_buffer(self, state_dict: OrderedDict):
         buffer = io.BytesIO()
@@ -57,9 +57,7 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict):
         return buffer

     def test_register_nf4_as_param(self):
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))

         # Would raise if nn.Parameter registration fails, such as no detach()
         # impl when calling __torch_dispatch__
@@ -69,18 +67,14 @@ def test_register_nf4_as_param(self):
     def test_output_bf16(self):
         # Test to ensure W4 A16 produces A16
         inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
         out = linear_nf4(input=inp, weight=nf4_tensor)
         assert out.dtype == torch.bfloat16

     def test_backward_bf16(self):
         # Test to ensure backward pass gives activation a bf16 gradient and no gradient
         # to the linear's weight, as it is frozen.
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
         inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
         linear_nf4(inp, nf4_tensor).sum().backward()
         assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
@@ -94,7 +88,7 @@ def test_reconstruction_qlora_vs_bnb(self):
         device = "cuda"
         embed_dim = 512
         input_weight = _build_input_weight(embed_dim, device)
-        nf4_weight = NF4Tensor.from_tensor(input_weight)
+        nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)
         bnb_reconstruction = bnb_linear(
             torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
@@ -118,7 +112,7 @@ def test_nf4_bnb_linear(self):
         dim = 512
         device = "cuda"
         input_weight = _build_input_weight(dim, device)
-        nf4_weight = NF4Tensor.from_tensor(input_weight)
+        nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)

         inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
@@ -170,5 +164,18 @@ def test_load_from_nf4_diff_meta(self):
         assert other_mod.param.block_size == 64
         assert other_mod.param.scaler_block_size == 1

+    def test_to_copy(self):
+        inpt_tensor = torch.rand(128, device='cpu')
+        inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
+        inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
+        torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
+
+        if torch.cuda.is_available():
+            inpt_tensor = torch.rand(128, device='cuda')
+            inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
+            inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
+            torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
+
+
 if __name__ == "__main__":
     unittest.main()
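
As context for the test changes above: the suite now builds its NF4 weights through the new to_nf4 helper instead of calling NF4Tensor.from_tensor directly, since from_tensor's block sizes are no longer defaulted. A minimal sketch of the before/after pattern, assuming torch and this branch of torchao are installed:

    import torch
    from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4, to_nf4

    weight = torch.randn(512, 512, dtype=torch.bfloat16)

    # Old pattern: the classmethod, whose block sizes are now required arguments.
    nf4_old = NF4Tensor.from_tensor(weight, 64, 256)

    # New pattern: to_nf4 casts to bfloat16 and fills in the 64/256 defaults.
    nf4_new = to_nf4(weight)

    # Either weight works with linear_nf4; W4 A16 keeps the bf16 activation dtype.
    inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
    out = linear_nf4(input=inp, weight=nf4_new)
    assert out.dtype == torch.bfloat16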

torchao/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+from . import dtypes
+
+__all__ = [
+    "dtypes"
+]

torchao/dtypes/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,8 @@
+from .nf4tensor import NF4Tensor, to_nf4
 from .uint4 import UInt4Tensor

 __all__ = [
+    "NF4Tensor",
+    "to_nf4",
     "UInt4Tensor"
 ]
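
With both __init__.py changes above, the NF4 pieces are re-exported at the package level, so (assuming the package is installed from this branch) imports along these lines should resolve:

    # Package-level re-exports added in this commit.
    from torchao import dtypes
    from torchao.dtypes import NF4Tensor, UInt4Tensor, to_nf4

    assert "NF4Tensor" in dtypes.__all__ and "to_nf4" in dtypes.__all__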

torchao/dtypes/nf4tensor.py

Lines changed: 16 additions & 2 deletions

@@ -36,6 +36,10 @@ def decorator(func):
 def noop_detach(func, *args, **kwargs):
     return args[0][0]

+@implements([torch.ops.aten._to_copy.default])
+def _to_copy(func, *args, **kwargs):
+    return args[0][0].get_original_weight().to(args[1]['dtype'])
+

 @implements(
     [
@@ -164,8 +168,8 @@ def __init__(
     def from_tensor(
         cls,
         inpt_tensor: torch.Tensor,
-        block_size: int = 64,
-        scaler_block_size: int = 256,
+        block_size: int,
+        scaler_block_size: int,
     ):
         assert inpt_tensor.dtype == torch.bfloat16
         assert (
@@ -452,6 +456,10 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
             inner_tensors["nf4"],
         )

+
+    def __str__(self):
+        return self.to(torch.float32).__str__()
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
@@ -501,3 +509,9 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
         weight: NF4Tensor weight
     """
     return LinearNF4.apply(input, weight)
+
+
+def to_nf4(tensor,
+           block_size: int = 64,
+           scaler_block_size: int = 256):
+    tensor1 = tensor.to(torch.bfloat16)
+    return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)
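
Taken together, the two additions in this file give a quantize/dequantize round trip: to_nf4 casts its input to bfloat16 and quantizes it, while the new aten._to_copy override lets .to(dtype) dequantize through get_original_weight(). A small sketch mirroring the new test_to_copy, with the block sizes and tolerances taken from that test:

    import torch
    from torchao.dtypes import to_nf4

    inpt_tensor = torch.rand(128, device="cpu")

    # Quantize with block_size=32 and scaler_block_size=2, as in the test.
    nf4 = to_nf4(inpt_tensor, 32, 2)

    # .to(dtype) dispatches to the _to_copy override, which reconstructs the
    # weight via get_original_weight() and casts it to the requested dtype.
    restored = nf4.to(torch.bfloat16)

    # NF4 is lossy, so the comparison uses the loose tolerances from the test.
    torch.testing.assert_allclose(inpt_tensor, restored, atol=0.13, rtol=0.13)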
