Skip to content

Commit f7e12c8

Browse files
authored
Change NF4Tensor dtype and add support for linear (#62)
1 parent 857b545 commit f7e12c8

File tree

3 files changed

+90
-3
lines changed

3 files changed

+90
-3
lines changed

.github/workflows/regression_test.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@ jobs:
2727
pip install torch
2828
2929
30+
- name: Install package
31+
run: |
32+
pip install .
33+
34+
- name: Run tests
35+
run: |
36+
pytest test
37+
38+
test-nightly:
39+
runs-on: 4-core-ubuntu-gpu-t4
40+
steps:
41+
- uses: actions/checkout@v2
42+
43+
- name: Set up Python
44+
uses: actions/setup-python@v2
45+
with:
46+
python-version: 3.9
47+
48+
- name: Install dependencies
49+
run: |
50+
python -m pip install --upgrade pip
51+
pip install -r requirements.txt
52+
pip install -r dev-requirements.txt
53+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
54+
55+
3056
- name: Install package
3157
run: |
3258
pip install .

test/dtypes/test_nf4.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch.nn.functional as F
99
import io
1010
from collections import OrderedDict
11+
import torchao
1112

1213
bnb_available = False
1314

@@ -176,6 +177,28 @@ def test_to_copy(self):
176177
inpt_tensor_bfloat16 = inpt_tensor_nf4.to(torch.bfloat16)
177178
torch.testing.assert_allclose(inpt_tensor, inpt_tensor_bfloat16, atol=0.13, rtol=0.13)
178179

180+
def test_to_bfloat16(self):
181+
inpt_tensor = torch.rand(128, dtype=torch.bfloat16)
182+
inpt_tensor_nf4 = to_nf4(inpt_tensor, 32, 2)
183+
assert type(inpt_tensor_nf4) != torch.Tensor
184+
assert type(inpt_tensor_nf4.to(torch.bfloat16)) == torch.Tensor
185+
assert inpt_tensor_nf4.to(torch.bfloat16).dtype == torch.bfloat16
186+
187+
def test_smoketest_linear(self):
188+
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
189+
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
190+
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
191+
out1 = torch.nn.functional.linear(inp, a)
192+
out2 = torch.nn.functional.linear(inp, a_nf4)
193+
194+
@unittest.skipIf(torch.__version__.split('+')[0] == '2.2.1', "Broken on stable.")
195+
def test_smoketest_linear_compile(self):
196+
a = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
197+
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
198+
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
199+
out3 = torch.compile(torch.nn.functional.linear, mode='max-autotune')(inp, a_nf4)
200+
201+
179202

180203
if __name__ == "__main__":
181204
unittest.main()

torchao/dtypes/nf4tensor.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,50 @@ def noop_detach(func, *args, **kwargs):
5151
# pyre-fixme[3]: Return type must be annotated.
5252
# pyre-fixme[2]: Parameter must be annotated.
5353
def _to_copy(func, *args, **kwargs):
54+
if not args[0][0].is_contiguous():
55+
assert args[0][0].t().is_contiguous()
56+
return func(args[0][0].t()).t()
5457
return args[0][0].get_original_weight().to(args[1]['dtype'])
5558

5659
@implements([torch.ops.aten.to.dtype])
5760
# pyre-fixme[3]: Return type must be annotated.
5861
# pyre-fixme[2]: Parameter must be annotated.
5962
def to_dtype(func, *args, **kwargs):
63+
if not args[0][0].is_contiguous():
64+
assert args[0][0].t().is_contiguous()
65+
return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
6066
return args[0][0].get_original_weight().to(args[0][1])
6167

68+
@implements([torch.ops.aten.t.default])
69+
# pyre-fixme[3]: Return type must be annotated.
70+
# pyre-fixme[2]: Parameter must be annotated.
71+
def t_default(func, *args, **kwargs):
72+
a = args[0][0]
73+
tensor_meta = SubclassTensorArgs(
74+
a.size(),
75+
(a.stride(1), a.stride(0)),
76+
a.storage_offset(),
77+
torch.bits2x4,
78+
a.device,
79+
a.requires_grad)
80+
b = NF4Tensor(
81+
tensor_meta,
82+
a.block_size,
83+
a.n_blocks,
84+
a.scaler_block_size,
85+
a.quantized_scalers,
86+
a.quantization_factor,
87+
a.scaler_mean,
88+
a.quantized_data,
89+
a.nf4)
90+
return b
91+
92+
@implements([torch.ops.aten.mm.default])
93+
# pyre-fixme[3]: Return type must be annotated.
94+
# pyre-fixme[2]: Parameter must be annotated.
95+
def mm_default(func, *args, **kwargs):
96+
return linear_nf4(args[0][0], args[0][1])
97+
6298

6399
@implements(
64100
[
@@ -160,7 +196,8 @@ def __new__(
160196
tensor_meta.original_shape,
161197
tensor_meta.original_strides,
162198
tensor_meta.storage_offset,
163-
dtype=tensor_meta.dtype,
199+
# Picked some floating dtype, but we need dtype extensibility
200+
dtype=torch.float8_e5m2fnuz,
164201
device=tensor_meta.device,
165202
requires_grad=tensor_meta.requires_grad,
166203
)
@@ -198,6 +235,7 @@ def from_tensor(
198235
block_size: int,
199236
scaler_block_size: int,
200237
):
238+
assert inpt_tensor.dim() <= 2
201239
assert inpt_tensor.dtype == torch.bfloat16
202240
assert (
203241
inpt_tensor.numel() % block_size == 0
@@ -428,7 +466,7 @@ def quantize_tensor_nearest(
428466
# pyre-fixme[40]: Static method `dequantize` cannot override a non-static method
429467
# defined in `torch._C.TensorBase`.
430468
def dequantize(value: torch.Tensor, nf4: torch.Tensor) -> torch.Tensor:
431-
"""Dequantize a nf4 value to float16 format"""
469+
"""Dequantize a nf4 value to bfloat16 format"""
432470
# return nf4.index_select(0, value)
433471
return nf4[value]
434472

@@ -546,7 +584,7 @@ class LinearNF4(torch.autograd.Function):
546584
def forward(ctx, input: torch.Tensor, weight: NF4Tensor):
547585
"""Save the quantized nf4 weight for backward pass"""
548586
ctx.nf4_weight = weight
549-
return F.linear(input, weight.get_original_weight())
587+
return F.linear(input, weight.to(input.dtype))
550588

551589
@staticmethod
552590
# pyre-fixme[14]: `backward` overrides method defined in `_SingleLevelFunction`

0 commit comments

Comments
 (0)