@@ -12,7 +12,6 @@
     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -29,14 +28,15 @@
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
+        self.e4m3_dtype = torch.float8_e4m3fn
         self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
+            input_dtype=self.e4m3_dtype,
+            weight_dtype=self.e4m3_dtype,
             output_dtype=torch.bfloat16,
         )
         self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
+            input_dtype=self.e4m3_dtype,
+            weight_dtype=self.e4m3_dtype,
             output_dtype=torch.bfloat16,
             transpose_input=True,
         )
@@ -146,6 +146,55 @@ def test_to_device(self):
         quantize_(linear, self.config)
         linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
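+        # cat along dim=0 should match quantizing the concatenated bf16 weight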
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # cat along dim == 1 is not really correct; it will be fixed
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat([linear1.weight.float8_data, linear2.weight.float8_data], dim=1)
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+
+    def test_transpose(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, self.config)
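+        # transposing the quantized weight makes the layer map 256 -> 128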
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertEqual(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertEqual(res.shape, (32, 128))
+
 
 if __name__ == "__main__":
     run_tests()