Commit 83420a1

Add support for resharding for fbgemm configs and int4 preshuffle kernel
Summary: added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpoint for resharding.

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat
python test/dtypes/test_int4_groupwise_preshuffle.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 6243040 commit 83420a1
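
A minimal usage sketch of what this commit enables, pieced together from the tests below (an illustration, not part of the commit; it assumes a CUDA sm90+ device and the fbgemm-backed float8 config exercised in test_fbgemm_fp8.py):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# the fbgemm kernel path is selected via the config; kernel="aten" is the other option
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), kernel="fbgemm")

linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear1, config)
quantize_(linear2, config)

# newly supported resharding building blocks on the quantized weight subclass
merged = torch.cat([linear1.weight, linear2.weight], dim=0)   # (512, 128)
flipped = linear1.weight.transpose(0, 1).contiguous()         # (128, 256)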

File tree

10 files changed: +1282 -123 lines
test/dtypes/test_fbgemm_fp8.py (90 additions & 29 deletions)
@@ -9,12 +9,14 @@
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
-from torchao.float8.config import e4m3_dtype
 from torchao.quantization import (
-    FbgemmConfig,
+    Float8DynamicActivationFloat8WeightConfig,
+    PerRow,
     quantize_,
 )
 from torchao.quantization.utils import compute_error
@@ -23,36 +25,35 @@
     is_sm_at_least_90,
 )
 
+FBGEMM_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(), kernel="fbgemm"
+)
+ATEN_CONFIG = Float8DynamicActivationFloat8WeightConfig(
+    granularity=PerRow(), kernel="aten"
+)
+
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
 class TestFbgemmFp8Tensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-            transpose_input=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
-    def test_linear(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20, f"sqnr: {sqnr}")
 
-    def test_slice(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -65,7 +66,7 @@ def test_slice(self):
             dummy.weight.narrow(1, 0, 128), requires_grad=False
         )
 
-        quantize_(dummy, self.config)
+        quantize_(dummy, config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
         self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
@@ -81,20 +82,23 @@ def test_slice(self):
         res_ref = dummy1(input)
         dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
         res = dummy(input)
-        assert compute_error(res, res_ref) > 25
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 25, f"sqnr: {sqnr}")
 
         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
         dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
         res = dummy(input)
-        assert compute_error(res, res_ref) > 15
+        sqnr = compute_error(res, res_ref)
+        self.assertTrue(sqnr > 15, f"sqnr: {sqnr}")
 
-    def test_slice_and_copy_(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_slice_and_copy_(self, config):
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
@@ -104,7 +108,7 @@ def test_slice_and_copy_(self):
 
         # dummy_l has random input (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
+        quantize_(dummy_l, config)
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
 
@@ -113,7 +117,8 @@ def test_slice_and_copy_(self):
         # making sure param.data is updated
         assert param.data.float8_data[0][0] != orig_value
 
-    def test_bmm(self):
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_bmm(self, config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -128,24 +133,80 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        # we need to transpose the weight first for bmm
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 20)
 
-    def test_to_device(self):
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)
 
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)
 
+    @parametrize("config", [FBGEMM_CONFIG, ATEN_CONFIG])
+    def test_cat(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        # concat with dim == 1 is not really correct and will be fixed later
+        # when we support distributed checkpointing
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        ref_float8_data = torch.cat(
+            [linear1.weight.float8_data, linear2.weight.float8_data], dim=1
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.float8_data, ref_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    @parametrize("config", [FBGEMM_CONFIG])
+    def test_transpose(self, config):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(torch.randn(128, dtype=dtype, device=device))
+        self.assertTrue(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 128))
+
+
+instantiate_parametrized_tests(TestFbgemmFp8Tensor)
+
 
 if __name__ == "__main__":
     run_tests()
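
As the new test_cat above pins down, cat along dim 0 concatenates both float8_data and the per-row scale (so it matches quantizing the unsharded weight), while cat along dim 1 concatenates only float8_data and reuses the first operand's scale, which the test itself flags as not really correct until distributed checkpointing support lands. A hedged sketch of how the dim-0 behavior could back a row-shard merge during resharding (merge_row_shards is a hypothetical helper, not part of this commit):

import torch

def merge_row_shards(shard_weights):
    # each entry is a float8-quantized weight holding one slice of dim 0
    # (output features); dim=0 cat joins float8_data and the per-row scale,
    # so the result matches quantizing the full, unsharded weight
    return torch.cat(list(shard_weights), dim=0)

# usage, assuming l0 and l1 are Linears already quantized with the fbgemm
# float8 config and holding consecutive row shards:
#   full_weight = merge_row_shards([l0.weight, l1.weight])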

test/dtypes/test_fbgemm_int4.py (48 additions & 1 deletion)
@@ -39,7 +39,6 @@ def setUp(self):
             weight_dtype=torch.int4,
             output_dtype=torch.bfloat16,
             block_size=[1, 1, 128],
-            transpose_input=True,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 
@@ -134,6 +133,7 @@ def forward(self, x):
         weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
         m = M(weight).eval()
         original = m(input)
+        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
@@ -152,6 +152,53 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+        dummy2 = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        dummy2.weight = torch.nn.Parameter(cat_weight2)
+        quantize_(dummy1, self.config)
+        quantize_(dummy2, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertEqual(dummy1.weight.packed_weight, cat_qweight1.packed_weight)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+        self.assertEqual(dummy1.weight.zero_point, cat_qweight1.zero_point)
+
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        self.assertEqual(dummy2.weight.packed_weight, cat_qweight2.packed_weight)
+        self.assertEqual(dummy2.weight.scale, cat_qweight2.scale)
+        self.assertEqual(dummy2.weight.zero_point, cat_qweight2.zero_point)
+
+    def test_transpose(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        # transpose again to return to the original state
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        self.assertTrue(linear1.weight.shape, (256, 128))
+
+        input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 256))
+
 
 if __name__ == "__main__":
     run_tests()
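
Both test files also drop transpose_input=True from the bmm configs; callers now pre-transpose the batched weight themselves before quantizing. A sketch of that pattern under stated assumptions (the real M module's forward is not shown in the hunks above, so the holder below is a stand-in, and the fbgemm float8 config from test_fbgemm_fp8.py is reused as the bmm config):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

class M(torch.nn.Module):
    # stand-in for the tests' M module: a plain holder of a batched weight
    def __init__(self, weight):
        super().__init__()
        self.weight = torch.nn.Parameter(weight)

bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), kernel="fbgemm")
m = M(torch.randn(10, 128, 256, dtype=torch.bfloat16, device="cuda")).eval()

# transpose_input=True is gone, so transpose the last two dims of the batched
# weight and make it contiguous before quantizing; filter_fn opts the module in
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
quantize_(m, bmm_config, filter_fn=lambda mod, fqn: True)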
