|
6 | 6 | from torch.testing._internal.common_utils import TestCase
7 | 7 | from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor
8 | 8 | import torch.nn.functional as F
9 |    | -
  | 9  | +import io
  | 10 | +from collections import OrderedDict
10 | 11 |
11 | 12 | bnb_available = False
12 | 13 |
@@ -44,6 +45,16 @@ def _build_bnb_linear(input_weight, device):
44 | 45 |
45 | 46 |
46 | 47 | class TestNF4Linear(TestCase):
   | 48 | +    class TestMod(nn.Module):
   | 49 | +        def __init__(self, tensor, block_size, scaler_block_size):
   | 50 | +            super().__init__()
   | 51 | +            self.param = torch.nn.Parameter(NF4Tensor.from_tensor(tensor, block_size, scaler_block_size))
   | 52 | +
   | 53 | +    def save_state_dict_to_buffer(self, state_dict: OrderedDict):
   | 54 | +        buffer = io.BytesIO()
   | 55 | +        torch.save(state_dict, buffer)
   | 56 | +        buffer.seek(0)
   | 57 | +        return buffer
47 | 58 |
48 | 59 |     def test_register_nf4_as_param(self):
49 | 60 |         nf4_tensor = NF4Tensor.from_tensor(
@@ -121,6 +132,43 @@ def test_nf4_bnb_linear(self):
121 | 132 |         assert err_native < 0.5 * dim
122 | 133 |         assert err_bnb < 0.5 * dim
123 | 134 |
    | 135 | +    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
    | 136 | +    def test_load_from_bfloat16(self):
    | 137 | +        """Tests loading a bf16 state dict into a module with an NF4 param"""
    | 138 | +        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
    | 139 | +        base_mod = self.TestMod(inpt_tensor, 32, 2)
    | 140 | +
    | 141 | +        bf16_dummy_dict = {"param": inpt_tensor}
    | 142 | +        base_mod.load_state_dict(bf16_dummy_dict)
    | 143 | +
    | 144 | +        assert base_mod.param.block_size == 32
    | 145 | +        assert base_mod.param.scaler_block_size == 2
    | 146 | +
    | 147 | +    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
    | 148 | +    def test_load_from_nf4_same_meta(self):
    | 149 | +        """Tests loading an NF4 state dict whose metadata matches the target module"""
    | 150 | +        inpt_tensor = torch.rand(64, device='cuda', dtype=torch.bfloat16)
    | 151 | +        base_mod = self.TestMod(inpt_tensor, 32, 2)
    | 152 | +        state_dict = base_mod.state_dict()
    | 153 | +        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
    | 154 | +
    | 155 | +        other_mod = self.TestMod(inpt_tensor, 32, 2)
    | 156 | +        other_mod.load_state_dict(torch.load(saved_state_dict))
    | 157 | +        assert other_mod.param.block_size == 32
    | 158 | +        assert other_mod.param.scaler_block_size == 2
    | 159 | +
    | 160 | +    @unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
    | 161 | +    def test_load_from_nf4_diff_meta(self):
    | 162 | +        """Tests loading an NF4 state dict whose metadata differs from the target module"""
    | 163 | +        inpt_tensor = torch.rand(128, device='cuda', dtype=torch.bfloat16)
    | 164 | +        base_mod = self.TestMod(inpt_tensor, 32, 2)
    | 165 | +        state_dict = base_mod.state_dict()
    | 166 | +        saved_state_dict = self.save_state_dict_to_buffer(state_dict)
    | 167 | +
    | 168 | +        other_mod = self.TestMod(inpt_tensor, 64, 1)
    | 169 | +        other_mod.load_state_dict(torch.load(saved_state_dict))
    | 170 | +        assert other_mod.param.block_size == 64
    | 171 | +        assert other_mod.param.scaler_block_size == 1
124 | 172 |
125 | 173 | if __name__ == "__main__":
126 | 174 |     unittest.main()
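
Outside the test harness, the round-trip these tests exercise looks like the following. A minimal sketch, assuming the same imports as the test file and a CUDA device; `Mod` is a hypothetical stand-in for the `TestMod` helper above:

```python
import io

import torch
from torchao.dtypes.nf4tensor import NF4Tensor


class Mod(torch.nn.Module):
    # Hypothetical module mirroring TestMod: a single NF4-quantized parameter.
    def __init__(self, tensor, block_size, scaler_block_size):
        super().__init__()
        self.param = torch.nn.Parameter(
            NF4Tensor.from_tensor(tensor, block_size, scaler_block_size)
        )


weight = torch.rand(64, device="cuda", dtype=torch.bfloat16)
src = Mod(weight, 32, 2)

# Round-trip the state dict through an in-memory buffer, as the tests do.
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

# Per test_load_from_nf4_diff_meta, the destination module keeps its own
# quantization metadata (64/1) rather than adopting the saved param's (32/2).
dst = Mod(weight, 64, 1)
dst.load_state_dict(torch.load(buffer))
assert dst.param.block_size == 64 and dst.param.scaler_block_size == 1
```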
|