 import torch
 from torch import nn
 from torch.testing._internal.common_utils import TestCase
-from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
+from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
 import torch.nn.functional as F
 import io
 from collections import OrderedDict
@@ -48,7 +48,7 @@ class TestNF4Linear(TestCase):
     class TestMod(nn.Module):
         def __init__(self, tensor, block_size, scaler_block_size):
             super().__init__()
-            self.param = torch.nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))
+            self.param = torch.nn.Parameter(to_nf4(tensor, block_size, scaler_block_size))

     def save_state_dict_to_buffer(self, state_dict: OrderedDict):
         buffer = io.BytesIO()
@@ -57,9 +57,7 @@ def save_state_dict_to_buffer(self, state_dict: OrderedDict):
         return buffer

     def test_register_nf4_as_param(self):
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))

         # Would raise if nn.Parameter registration fails, such as no detach()
         # impl when calling __torch_dispatch__
@@ -69,18 +67,14 @@ def test_register_nf4_as_param(self):
     def test_output_bf16(self):
         # Test to ensure W4 A16 produces A16
         inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
         out = linear_nf4(input=inp, weight=nf4_tensor)
         assert out.dtype == torch.bfloat16

     def test_backward_bf16(self):
         # Test to ensure backward pass gives activation a bf16 gradient and no gradient
         # to the linear's weight, as it is frozen.
-        nf4_tensor = NF4Tensor.from_tensor(
-            inpt_tensor=torch.randn(512, 512, dtype=torch.bfloat16)
-        )
+        nf4_tensor = to_nf4(torch.randn(512, 512, dtype=torch.bfloat16))
         inp = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
         linear_nf4(inp, nf4_tensor).sum().backward()
         assert inp.grad is not None and inp.grad.dtype == torch.bfloat16
@@ -94,7 +88,7 @@ def test_reconstruction_qlora_vs_bnb(self):
         device = "cuda"
         embed_dim = 512
         input_weight = _build_input_weight(embed_dim, device)
-        nf4_weight = NF4Tensor.from_tensor(input_weight)
+        nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)
         bnb_reconstruction = bnb_linear(
             torch.eye(embed_dim, embed_dim, dtype=torch.bfloat16, device=device)
@@ -118,7 +112,7 @@ def test_nf4_bnb_linear(self):
         dim = 512
         device = "cuda"
         input_weight = _build_input_weight(dim, device)
-        nf4_weight = NF4Tensor.from_tensor(input_weight)
+        nf4_weight = to_nf4(input_weight)
         bnb_linear = _build_bnb_linear(input_weight, device)

         inp = torch.randn(2, 512, dtype=torch.bfloat16, device="cuda")
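For reference, a minimal sketch of the call-site migration these hunks apply, using only the NF4Tensor.from_tensor and to_nf4 names as they appear in the diff above; the weight shape and dtype are copied from the tests and are otherwise arbitrary.

    import torch
    from torchao.dtypes.nf4tensor import to_nf4

    weight = torch.randn(512, 512, dtype=torch.bfloat16)

    # Before this change, the tests built NF4 weights via the classmethod:
    #   nf4_tensor = NF4Tensor.from_tensor(inpt_tensor=weight)
    # After this change, they call the to_nf4 helper instead:
    nf4_tensor = to_nf4(weight)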