Skip to content

Commit dfcda29

Browse files
committed
Add support for resharding for fbgemm configs and int4 preshuffle kernel
Summary: added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpoint for resharding Test Plan: python test/dtypes/test_fbgemm_int4.py -k test_transpose python test/dtypes/test_fbgemm_int4.py -k test_cat python test/dtypes/test_fbgemm_fp8.py -k test_transpose python test/dtypes/test_fbgemm_fp8.py -k test_cat python test/dtypes/test_int4_groupwise_preshuffle.py Reviewers: Subscribers: Tasks: Tags:
1 parent 6243040 commit dfcda29

File tree

10 files changed

+1268
-123
lines changed

10 files changed

+1268
-123
lines changed

test/dtypes/test_fbgemm_fp8.py

Lines changed: 91 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from torch.testing._internal.common_utils import (
1111
TestCase,
1212
run_tests,
13+
parametrize,
14+
instantiate_parametrized_tests,
1315
)
1416

15-
from torchao.float8.config import e4m3_dtype
1617
from torchao.quantization import (
17-
FbgemmConfig,
18+
Float8DynamicActivationFloat8WeightConfig,
19+
PerRow,
1820
quantize_,
1921
)
2022
from torchao.quantization.utils import compute_error
@@ -24,35 +26,35 @@
2426
)
2527

2628

29+
FBGEMM_CONFIG = Float8DynamicActivationFloat8WeightConfig(
30+
granularity=PerRow(), kernel="fbgemm"
31+
)
32+
ATEN_CONFIG = Float8DynamicActivationFloat8WeightConfig(
33+
granularity=PerRow(), kernel="aten"
34+
)
35+
36+
2737
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
2838
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2939
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
3040
class TestFbgemmFp8Tensor(TestCase):
3141
def setUp(self):
32-
self.config = FbgemmConfig(
33-
input_dtype=e4m3_dtype,
34-
weight_dtype=e4m3_dtype,
35-
output_dtype=torch.bfloat16,
36-
)
37-
self.bmm_config = FbgemmConfig(
38-
input_dtype=e4m3_dtype,
39-
weight_dtype=e4m3_dtype,
40-
output_dtype=torch.bfloat16,
41-
transpose_input=True,
42-
)
4342
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
4443

45-
def test_linear(self):
44+
@parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
45+
def test_linear(self, config):
4646
dtype = torch.bfloat16
4747
device = "cuda"
4848
input = torch.randn(1, 128, dtype=dtype, device=device)
4949
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
5050
original = linear(input)
51-
quantize_(linear, self.config)
51+
quantize_(linear, config)
5252
quantized = linear(input)
53-
self.assertTrue(compute_error(original, quantized) > 20)
53+
sqnr = compute_error(original, quantized)
54+
self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")
5455

55-
def test_slice(self):
56+
@parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
57+
def test_slice(self, config):
5658
dtype = torch.bfloat16
5759
device = "cuda"
5860
dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -65,7 +67,7 @@ def test_slice(self):
6567
dummy.weight.narrow(1, 0, 128), requires_grad=False
6668
)
6769

68-
quantize_(dummy, self.config)
70+
quantize_(dummy, config)
6971
weight1 = dummy.weight.narrow(0, 0, 64)
7072
weight2 = dummy.weight.narrow(1, 0, 128)
7173
self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
@@ -81,20 +83,23 @@ def test_slice(self):
8183
res_ref = dummy1(input)
8284
dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
8385
res = dummy(input)
84-
assert compute_error(res, res_ref) > 25
86+
sqnr = compute_error(res, res_ref)
87+
self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")
8588

8689
input = torch.randn(2, 128, dtype=dtype, device=device)
8790
res_ref = dummy2(input)
8891
dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
8992
res = dummy(input)
90-
assert compute_error(res, res_ref) > 15
93+
sqnr = compute_error(res, res_ref)
94+
self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
9195

92-
def test_slice_and_copy_(self):
96+
@parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
97+
def test_slice_and_copy_(self, config):
9398
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
9499
l.weight = torch.nn.Parameter(
95100
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
96101
)
97-
quantize_(l, self.config)
102+
quantize_(l, config)
98103
param = l.weight
99104
param_data = param.data
100105
param_data = param_data.narrow(0, 0, 512)
@@ -104,7 +109,7 @@ def test_slice_and_copy_(self):
104109

105110
# dummy_l has random input (shouldn't be 0)
106111
dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
107-
quantize_(dummy_l, self.config)
112+
quantize_(dummy_l, config)
108113
quantized = dummy_l.weight
109114
quantized = quantized.narrow(0, 0, 512)
110115

@@ -113,7 +118,8 @@ def test_slice_and_copy_(self):
113118
# making sure param.data is updated
114119
assert param.data.float8_data[0][0] != orig_value
115120

116-
def test_bmm(self):
121+
@parametrize("config", [FBGEMM_CONFIG])
122+
def test_bmm(self, config):
117123
class M(torch.nn.Module):
118124
def __init__(self, weight):
119125
super().__init__()
@@ -128,24 +134,80 @@ def forward(self, x):
128134
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
129135
m = M(weight).eval()
130136
original = m(input)
131-
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
137+
# we need to transpose the weight first for bmm
138+
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
139+
quantize_(m, config, filter_fn=lambda x, fqn: True)
132140
quantized = m(input)
133141
self.assertTrue(compute_error(original, quantized) > 20)
134142

135-
def test_to_device(self):
143+
@parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
144+
def test_to_device(self, config):
136145
for device in self.GPU_DEVICES:
137146
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
138-
quantize_(linear, self.config)
147+
quantize_(linear, config)
139148
linear.to(device)
140149

141150
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
142-
quantize_(linear, self.config)
151+
quantize_(linear, config)
143152
linear.to(device=device)
144153

145154
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
146-
quantize_(linear, self.config)
155+
quantize_(linear, config)
147156
linear.to(device)
148157

158+
@parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
159+
def test_cat(self, config):
160+
dtype = torch.bfloat16
161+
device = "cuda"
162+
# weight: (256, 128)
163+
linear1 = torch.nn.Linear(128, 256, dtype=dtype)
164+
# weight: (256, 128)
165+
linear2 = torch.nn.Linear(128, 256, dtype=dtype)
166+
167+
cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
168+
dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
169+
170+
dummy1.weight = torch.nn.Parameter(cat_weight1)
171+
quantize_(dummy1, config)
172+
173+
quantize_(linear1, config)
174+
quantize_(linear2, config)
175+
176+
cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
177+
self.assertTrue(cat_qweight1.shape, (512, 128))
178+
self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
179+
self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
180+
181+
# concat with dim == 1 is not really correct and will be fixed later
182+
# when we support distributed checkpointing
183+
cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
184+
self.assertTrue(cat_qweight2.shape, (256, 256))
185+
ref_float8_data = torch.cat(
186+
[linear1.weight.float8_data, linear2.weight.float8_data], dim=1
187+
)
188+
ref_scale = linear1.weight.scale
189+
self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
190+
self.assertEqual(cat_qweight2.scale, ref_scale)
191+
192+
@parametrize("config", [FBGEMM_CONFIG])
193+
def test_transpose(self, config):
194+
dtype = torch.bfloat16
195+
device = "cuda"
196+
# weight: (256, 128)
197+
linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
198+
quantize_(linear1, config)
199+
linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
200+
linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
201+
self.assertTrue(linear1.weight.shape, (128, 256))
202+
203+
input = torch.randn(32, 256, dtype=dtype, device=device)
204+
# make sure it runs
205+
res = linear1(input)
206+
self.assertTrue(res.shape, (32, 128))
207+
208+
209+
instantiate_parametrized_tests(TestFbgemmFp8Tensor)
210+
149211

150212
if __name__ == "__main__":
151213
run_tests()

test/dtypes/test_fbgemm_int4.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def setUp(self):
3939
weight_dtype=torch.int4,
4040
output_dtype=torch.bfloat16,
4141
block_size=[1, 1, 128],
42-
transpose_input=True,
4342
)
4443
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
4544

@@ -134,6 +133,7 @@ def forward(self, x):
134133
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
135134
m = M(weight).eval()
136135
original = m(input)
136+
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
137137
quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
138138
quantized = m(input)
139139
self.assertTrue(compute_error(original, quantized) > 18)
@@ -152,6 +152,53 @@ def test_to_device(self):
152152
quantize_(linear, self.config)
153153
linear.to(device)
154154

155+
def test_cat(self):
156+
dtype = torch.bfloat16
157+
device = "cuda"
158+
# weight: (256, 128)
159+
linear1 = torch.nn.Linear(128, 256, dtype=dtype)
160+
# weight: (256, 128)
161+
linear2 = torch.nn.Linear(128, 256, dtype=dtype)
162+
163+
cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
164+
cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
165+
dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
166+
dummy2 = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
167+
168+
dummy1.weight = torch.nn.Parameter(cat_weight1)
169+
dummy2.weight = torch.nn.Parameter(cat_weight2)
170+
quantize_(dummy1, self.config)
171+
quantize_(dummy2, self.config)
172+
173+
quantize_(linear1, self.config)
174+
quantize_(linear2, self.config)
175+
176+
cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
177+
self.assertTrue(cat_qweight1.shape, (512, 128))
178+
self.assertEqual(dummy1.weight.packed_weight, cat_qweight1.packed_weight)
179+
self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
180+
self.assertEqual(dummy1.weight.zero_point, cat_qweight1.zero_point)
181+
182+
cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
183+
self.assertTrue(cat_qweight2.shape, (256, 256))
184+
self.assertEqual(dummy2.weight.packed_weight, cat_qweight2.packed_weight)
185+
self.assertEqual(dummy2.weight.scale, cat_qweight2.scale)
186+
self.assertEqual(dummy2.weight.zero_point, cat_qweight2.zero_point)
187+
188+
def test_transpose(self):
189+
# weight: (256, 128)
190+
linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
191+
quantize_(linear1, self.config)
192+
linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
193+
# transpose again to return to the original state
194+
linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
195+
self.assertTrue(linear1.weight.shape, (256, 128))
196+
197+
input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
198+
# make sure it runs
199+
res = linear1(input)
200+
self.assertTrue(res.shape, (32, 256))
201+
155202

156203
if __name__ == "__main__":
157204
run_tests()

0 commit comments

Comments
 (0)