Commit 145b1df

Merge remote-tracking branch 'origin/main' into nf4to1
2 parents 9f3c6ae + ebde5e6

2 files changed: 85 additions, 1 deletion

test/modules/test_nf4_linear.py

Lines changed: 49 additions & 1 deletion
@@ -6,7 +6,8 @@
 from torch.testing._internal.common_utils import TestCase
 from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
 import torch.nn.functional as F
-
+import io
+from collections import OrderedDict

 bnb_available = False

@@ -44,6 +45,16 @@ def _build_bnb_linear(input_weight, device):


 class TestNF4Linear(TestCase):
+    class TestMod(nn.Module):
+        def __init__(self, tensor, block_size, scaler_block_size):
+            super().__init__()
+            self.param = torch.nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))
+
+    def save_state_dict_to_buffer(self, state_dict: OrderedDict):
+        buffer = io.BytesIO()
+        torch.save(state_dict, buffer)
+        buffer.seek(0)
+        return buffer

     def test_register_nf4_as_param(self):
         nf4_tensor = NF4Tensor.from_tensor(
@@ -121,6 +132,43 @@ def test_nf4_bnb_linear(self):
         assert err_native < 0.5 * dim
         assert err_bnb < 0.5 * dim

+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_bfloat16(self):
+        """Tests loading to and from different module state dicts"""
+        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+
+        bf16_dummy_dict = {"param": inpt_tensor}
+        base_mod.load_state_dict(bf16_dummy_dict)
+
+        assert base_mod.param.block_size == 32
+        assert base_mod.param.scaler_block_size == 2
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_nf4_same_meta(self):
+        """Tests loading to and from different module state dicts"""
+        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        state_dict = base_mod.state_dict()
+        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
+
+        other_mod = self.TestMod(inpt_tensor, 32, 2)
+        other_mod.load_state_dict(torch.load(saved_state_dict))
+        assert other_mod.param.block_size == 32
+        assert other_mod.param.scaler_block_size == 2
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_nf4_diff_meta(self):
+        """Tests loading to and from different module state dicts"""
+        inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        state_dict = base_mod.state_dict()
+        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
+
+        other_mod = self.TestMod(inpt_tensor, 64, 1)
+        other_mod.load_state_dict(torch.load(saved_state_dict))
+        assert other_mod.param.block_size == 64
+        assert other_mod.param.scaler_block_size == 1

 if __name__ == "__main__":
     unittest.main()
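For orientation (not part of the commit): the three new tests round-trip an NF4-quantized parameter through a module state dict, both from a plain bfloat16 tensor and from NF4Tensors with matching or mismatching quantization metadata. A minimal sketch of that flow, under the same assumptions as the tests (CUDA available, the NF4Tensor API shown above); the module name here is illustrative, not from the commit:

import io
import torch
import torch.nn as nn
from torchao.dtypes.nf4tensor import NF4Tensor

class NF4Holder(nn.Module):  # illustrative stand-in for the TestMod defined above
    def __init__(self, tensor, block_size, scaler_block_size):
        super().__init__()
        # Wrapping the NF4Tensor in a Parameter is what lets load_state_dict
        # copy into it later.
        self.param = nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))

weight = torch.rand(128, device="cuda", dtype=torch.bfloat16)
src = NF4Holder(weight, 32, 2)
dst = NF4Holder(weight, 64, 1)

# Save the NF4 state dict to an in-memory buffer, then load it into a module
# whose parameter was quantized with different block sizes (mirrors the tests'
# save_state_dict_to_buffer / torch.load usage).
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)
dst.load_state_dict(torch.load(buffer))

# The destination parameter keeps its own quantization metadata.
assert dst.param.block_size == 64 and dst.param.scaler_block_size == 1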

torchao/dtypes/nf4tensor.py

Lines changed: 36 additions & 0 deletions
@@ -13,6 +13,14 @@
 NF4_OPS_TABLE: Dict[Any, Any] = {}


+def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
+    both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
+    return (
+        both_nf4 and
+        a.block_size == b.block_size
+        and a.scaler_block_size == b.scaler_block_size
+        and a.n_blocks == b.n_blocks
+    )

 def implements(aten_ops):
     """Use this decorator to implement a function for an aten op in __torch_dispatch__"""
@@ -33,6 +41,34 @@ def _to_copy(func, *args, **kwargs):
     return args[0][0].get_original_weight().to(args[1]['dtype'])


+@implements(
+    [
+        aten.copy_.default,
+    ]
+)
+def copy_(func, *args, **kwargs):
+    original: NF4Tensor = args[0][0]
+    copy_in: torch.Tensor = args[0][1]
+
+    # Base case: metadata matches, so copy the quantized inner tensors directly
+    if same_metadata(original, copy_in):
+        original_tensors = original.__tensor_flatten__()[0]
+        for tensor_name in original_tensors:
+            getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name))
+        return
+
+    # Source is not an NF4Tensor: quantize it with the destination's metadata, then copy
+    if not isinstance(copy_in, NF4Tensor):
+        copy_in_nf4 = NF4Tensor.from_tensor(copy_in, original.block_size, original.scaler_block_size)
+        return original.copy_(copy_in_nf4)
+
+    # Source is an NF4Tensor with different metadata: dequantize and requantize to match
+    full_precision = copy_in.get_original_weight()
+    same_meta_nf4 = NF4Tensor.from_tensor(
+        full_precision, original.block_size, original.scaler_block_size
+    )
+    return original.copy_(same_meta_nf4)
+
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
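For context on how the new op is exercised: when load_state_dict copies a saved tensor into an NF4-backed Parameter, the copy dispatches to aten.copy_.default and lands in the handler above. A minimal sketch of the three branches, assuming the NF4Tensor API shown in this diff and a CUDA device (the variable names are illustrative, not from the commit):

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

values = torch.rand(128, device="cuda", dtype=torch.bfloat16)

dst = NF4Tensor.from_tensor(values, 32, 2)   # destination: block_size=32, scaler_block_size=2
same = NF4Tensor.from_tensor(values, 32, 2)  # same metadata as dst
diff = NF4Tensor.from_tensor(values, 64, 1)  # different metadata from dst

dst.copy_(same)  # base case: inner tensors copied directly
dst.copy_(diff)  # metadata mismatch: dequantized, then requantized to dst's layout
dst.copy_(torch.rand(128, device="cuda", dtype=torch.bfloat16))  # non-NF4 source: quantized first

# In every branch the destination keeps its own quantization metadata.
assert dst.block_size == 32 and dst.scaler_block_size == 2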
