Skip to content

Commit 5b3ee6e

Browse files
committed
Updated to_nf4 test
1 parent 9c1fdfd commit 5b3ee6e

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

test/dtypes/test_uint4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
compute_error,
1919
)
2020
from torchao.quantization.quant_api import (
21-
replace_with_custom_fn_if_matches_filter,
21+
_replace_with_custom_fn_if_matches_filter,
2222
)
2323
from torch.ao.quantization.observer import ObserverBase
2424
from torch import nn
@@ -36,7 +36,7 @@ def fn(mod):
3636
mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False)
3737
return mod
3838

39-
replace_with_custom_fn_if_matches_filter(
39+
_replace_with_custom_fn_if_matches_filter(
4040
model,
4141
lambda mod: fn(mod),
4242
lambda mod, fqn: isinstance(mod, torch.nn.Linear),

test/modules/test_nf4_linear.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch import nn
66
from torch.testing._internal.common_utils import TestCase
7-
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
7+
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
88
import torch.nn.functional as F
99
import io
1010
from collections import OrderedDict
@@ -48,7 +48,7 @@ class TestNF4Linear(TestCase):
4848
class TestMod(nn.Module):
4949
def __init__(self, tensor, block_size, scaler_block_size):
5050
super().__init__()
51-
self.param = torch.nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))
51+
self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size))
5252

5353
def save_state_dict_to_buffer(self, state_dict: OrderedDict):
5454
buffer = io.BytesIO()
@@ -57,9 +57,7 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict):
5757
return buffer
5858

5959
def test_register_nf4_as_param(self):
60-
nf4_tensor = NF4Tensor.from_tensor(
61-
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
62-
)
60+
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
6361

6462
# Would raise if nn.Parameter registration fails, such as no detach()
6563
# impl when calling __torch_dispatch__
@@ -69,18 +67,14 @@ def test_register_nf4_as_param(self):
6967
def test_output_bf16(self):
7068
# Test to ensure W4 A16 produces A16
7169
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
72-
nf4_tensor = NF4Tensor.from_tensor(
73-
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
74-
)
70+
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
7571
out = linear_nf4(input=inp, weight=nf4_tensor)
7672
assert out.dtype == torch.bfloat16
7773

7874
def test_backward_bf16(self):
7975
# Test to ensure backward pass gives activation a bf16 gradient and no gradient
8076
# to the linear's weight, as it is frozen.
81-
nf4_tensor = NF4Tensor.from_tensor(
82-
inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
83-
)
77+
nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
8478
inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
8579
linear_nf4(inp, nf4_tensor).sum().backward()
8680
assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
@@ -94,7 +88,7 @@ def test_reconstruction_qlora_vs_bnb(self):
9488
device = "cuda"
9589
embed_dim = 512
9690
input_weight = _build_input_weight(embed_dim, device)
97-
nf4_weight = NF4Tensor.from_tensor(input_weight)
91+
nf4_weight = to_nf4(input_weight)
9892
bnb_linear = _build_bnb_linear(input_weight, device)
9993
bnb_reconstruction = bnb_linear(
10094
torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
@@ -118,7 +112,7 @@ def test_nf4_bnb_linear(self):
118112
dim = 512
119113
device = "cuda"
120114
input_weight = _build_input_weight(dim, device)
121-
nf4_weight = NF4Tensor.from_tensor(input_weight)
115+
nf4_weight = to_nf4(input_weight)
122116
bnb_linear = _build_bnb_linear(input_weight, device)
123117

124118
inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")

torchao/dtypes/nf4tensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def __init__(
168168
def from_tensor(
169169
cls,
170170
inpt_tensor: torch.Tensor,
171-
block_size: int = 64,
172-
scaler_block_size: int = 256,
171+
block_size: int,
172+
scaler_block_size: int,
173173
):
174174
assert inpt_tensor.dtype == torch.bfloat16
175175
assert (
@@ -510,6 +510,8 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor:
510510
"""
511511
return LinearNF4.apply(input, weight)
512512

513-
def to_nf4(tensor):
513+
def to_nf4(tensor,
514+
block_size: int = 64,
515+
scaler_block_size: int = 256):
514516
tensor1 = tensor.to(torch.bfloat16)
515-
return NF4Tensor.from_tensor(tensor1)
517+
return NF4Tensor.from_tensor(tensor1, block_size, scaler_block_size)

0 commit comments

Comments
 (0)