
Commit 177ae9b

add copy_ dispatch and some tests
1 parent 55e5d40 commit 177ae9b

2 files changed: +85 -1 lines changed

test/modules/test_nf4_linear.py

Lines changed: 49 additions & 1 deletion
@@ -6,7 +6,8 @@
 from torch.testing._internal.common_utils import TestCase
 from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
 import torch.nn.functional as F
-
+import io
+from collections import OrderedDict

 bnb_available = False

@@ -44,6 +45,16 @@ def _build_bnb_linear(input_weight, device):


 class TestNF4Linear(TestCase):
+    class TestMod(nn.Module):
+        def __init__(self, tensor, block_size, scaler_block_size):
+            super().__init__()
+            self.param = torch.nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))
+
+    def save_state_dict_to_buffer(self, state_dict: OrderedDict):
+        buffer = io.BytesIO()
+        torch.save(state_dict, buffer)
+        buffer.seek(0)
+        return buffer

     def test_register_nf4_as_param(self):
         nf4_tensor = NF4Tensor.from_tensor(
@@ -121,6 +132,43 @@ def test_nf4_bnb_linear(self):
         assert err_native < 0.5 * dim
         assert err_bnb < 0.5 * dim

+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_bfloat16(self):
+        """Tests loading a bf16 state dict into a module with an NF4Tensor param"""
+        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+
+        bf16_dummy_dict = {"param": inpt_tensor}
+        base_mod.load_state_dict(bf16_dummy_dict)
+
+        assert base_mod.param.block_size == 32
+        assert base_mod.param.scaler_block_size == 2
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_nf4_same_meta(self):
+        """Tests loading an NF4 state dict whose quantization metadata matches the target"""
+        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        state_dict = base_mod.state_dict()
+        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
+
+        other_mod = self.TestMod(inpt_tensor, 32, 2)
+        other_mod.load_state_dict(torch.load(saved_state_dict))
+        assert other_mod.param.block_size == 32
+        assert other_mod.param.scaler_block_size == 2
+
+    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
+    def test_load_from_nf4_diff_meta(self):
+        """Tests loading an NF4 state dict whose quantization metadata differs from the target"""
+        inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
+        base_mod = self.TestMod(inpt_tensor, 32, 2)
+        state_dict = base_mod.state_dict()
+        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
+
+        other_mod = self.TestMod(inpt_tensor, 64, 1)
+        other_mod.load_state_dict(torch.load(saved_state_dict))
+        assert other_mod.param.block_size == 64
+        assert other_mod.param.scaler_block_size == 1

 if __name__ == "__main__":
     unittest.main()

torchao/dtypes/nf4tensor.py

Lines changed: 36 additions & 0 deletions
@@ -13,6 +13,14 @@
 NF4_OPS_TABLE: Dict[Any, Any] = {}


+def same_metadata(a: "NF4Tensor", b: "NF4Tensor"):
+    both_nf4 = isinstance(a, NF4Tensor) and isinstance(b, NF4Tensor)
+    return (
+        both_nf4
+        and a.block_size == b.block_size
+        and a.scaler_block_size == b.scaler_block_size
+        and a.n_blocks == b.n_blocks
+    )

 def implements(aten_ops):
     """Use this decorator to implement a function for an aten op in __torch_dispatch__"""
@@ -29,6 +37,34 @@ def noop_detach(func, *args, **kwargs):
     return args[0][0]


+@implements(
+    [
+        aten.copy_.default,
+    ]
+)
+def copy_(func, *args, **kwargs):
+    original: NF4Tensor = args[0][0]
+    copy_in: torch.Tensor = args[0][1]
+
+    # Base case: both are NF4Tensors with matching metadata, copy the inner tensors directly
+    if same_metadata(original, copy_in):
+        original_tensors = original.__tensor_flatten__()[0]
+        for tensor_name in original_tensors:
+            getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name))
+        return
+
+    # Source is not an NF4Tensor: quantize it with the destination's metadata, then copy
+    if not isinstance(copy_in, NF4Tensor):
+        copy_in_nf4 = NF4Tensor.from_tensor(copy_in, original.block_size, original.scaler_block_size)
+        return original.copy_(copy_in_nf4)
+
+    # Both are NF4Tensors but the metadata differs: dequantize, then requantize to match
+    full_precision = copy_in.get_original_weight()
+    same_meta_nf4 = NF4Tensor.from_tensor(
+        full_precision, original.block_size, original.scaler_block_size
+    )
+    return original.copy_(same_meta_nf4)
+
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
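
A minimal usage sketch of the copy_ dispatch added above. This is not part of the commit; it assumes a CUDA device and only the NF4Tensor API that appears in this diff (from_tensor, get_original_weight, block_size, scaler_block_size), with sizes mirroring the tests:

import torch
from torchao.dtypes.nf4tensor import NF4Tensor

# Quantize a bf16 weight into an NF4Tensor, using the same block sizes as the tests.
weight = torch.rand(128, device="cuda", dtype=torch.bfloat16)
nf4_weight = NF4Tensor.from_tensor(weight, 32, 2)  # block_size=32, scaler_block_size=2

# Copying a plain bf16 tensor in place takes the "not isinstance(copy_in, NF4Tensor)"
# branch: the source is first quantized with the destination's metadata.
new_weight = torch.rand(128, device="cuda", dtype=torch.bfloat16)
nf4_weight.copy_(new_weight)

# Copying from an NF4Tensor with different metadata takes the last branch:
# the source is dequantized via get_original_weight() and requantized to match.
other = NF4Tensor.from_tensor(new_weight, 64, 1)
nf4_weight.copy_(other)

The load_state_dict calls in the tests above reach the same op, since loading a state dict copies each incoming tensor into the existing parameter in place.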
