
Commit a39fd37

Add support for resharding for fbgemm configs
Summary:
Added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpointing for resharding.

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 6243040 commit a39fd37
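
Usage note (not part of this commit's diff): a minimal sketch of the ops this change enables on fbgemm-quantized weights, mirroring the setup in test/dtypes/test_fbgemm_fp8.py. It assumes a CUDA device with the required fbgemm fp8 kernels (sm90+ per the tests) and uses only the FbgemmConfig arguments shown in that file.

# Sketch only: mirrors the new tests; not an official resharding recipe.
import torch
from torchao.quantization import FbgemmConfig, quantize_

config = FbgemmConfig(
    input_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    output_dtype=torch.bfloat16,
)

# two (256, 128) weight shards, quantized independently
linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear1, config)
quantize_(linear2, config)

# cat along dim 0 merges the quantized shards into one (512, 128) weight
merged = torch.cat([linear1.weight, linear2.weight], dim=0)

# transpose is also supported directly on the quantized tensor subclass
transposed = linear1.weight.transpose(0, 1).contiguous()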

File tree

7 files changed: +460 -61 lines changed


test/dtypes/test_fbgemm_fp8.py

Lines changed: 54 additions & 10 deletions
@@ -12,7 +12,6 @@
     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
     FbgemmConfig,
     quantize_,
@@ -29,17 +28,12 @@
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
+        self.e4m3_dtype = torch.float8_e4m3fn
         self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
+            input_dtype=self.e4m3_dtype,
+            weight_dtype=self.e4m3_dtype,
             output_dtype=torch.bfloat16,
         )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
     def test_linear(self):
@@ -128,7 +122,9 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, self.config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
@@ -146,6 +142,54 @@ def test_to_device(self):
         quantize_(linear, self.config)
         linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat with dim == 1 is not really correct and will be fixed later
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    def test_transpose(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertTrue(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 128))
+
 
 if __name__ == "__main__":
     run_tests()

test/dtypes/test_fbgemm_int4.py

Lines changed: 48 additions & 1 deletion
@@ -39,7 +39,6 @@ def setUp(self):
             weight_dtype=torch.int4,
             output_dtype=torch.bfloat16,
             block_size=[1, 1, 128],
-            transpose_input=True,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
@@ -134,6 +133,7 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
@@ -152,6 +152,53 @@ def test_to_device(self):
         quantize_(linear, self.config)
         linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+        dummy2 = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        dummy2.weight = torch.nn.Parameter(cat_weight2)
+        quantize_(dummy1, self.config)
+        quantize_(dummy2, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.packed_weight, cat_qweight1.packed_weight)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+        self.assertEqual(dummy1.weight.zero_point, cat_qweight1.zero_point)
+
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        self.assertEqual(dummy2.weight.packed_weight, cat_qweight2.packed_weight)
+        self.assertEqual(dummy2.weight.scale, cat_qweight2.scale)
+        self.assertEqual(dummy2.weight.zero_point, cat_qweight2.zero_point)
+
+    def test_transpose(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        # transpose again to return to the original state
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        self.assertTrue(linear1.weight.shape, (256, 128))
+
+        input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 256))
+
 
 if __name__ == "__main__":
     run_tests()

torchao/core/config.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,14 @@
 
 import torch
 
+__all__ = [
+    "AOBaseConfig",
+    "VersionMismatchError",
+    "config_to_dict",
+    "config_from_dict",
+    "ALLOWED_AO_MODULES",
+]
+
 
 class AOBaseConfig(abc.ABC):
     """
