Commit 4e562e2

Add support for resharding for fbgemm configs
Summary:
Added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpoint for resharding.

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 6243040 commit 4e562e2
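
The sketch below is illustrative only (not part of the commit); it shows roughly how the new cat and transpose support is exercised through the public API, following the tests in the Test Plan. `config` stands in for the fbgemm quantization config those tests construct.

# Hedged sketch, mirroring the tests in the Test Plan; `config` is a stand-in
# for the fbgemm quantization config constructed in the test files.
import torch
from torchao.quantization import quantize_

dtype = torch.bfloat16
linear1 = torch.nn.Linear(128, 256, dtype=dtype, device="cuda")
linear2 = torch.nn.Linear(128, 256, dtype=dtype, device="cuda")
quantize_(linear1, config)  # weights become fbgemm quantized tensor subclasses
quantize_(linear2, config)

# concatenating quantized weights now concatenates the underlying data and scales
stacked = torch.cat([linear1.weight, linear2.weight], dim=0)  # shape (512, 128)

# transposing a quantized weight returns a quantized tensor with transposed data
flipped = linear1.weight.transpose(0, 1).contiguous()  # shape (128, 256)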

File tree: 4 files changed, +389 -20 lines changed


test/dtypes/test_fbgemm_fp8.py

Lines changed: 47 additions & 0 deletions
@@ -146,6 +146,53 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy1, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        self.assertEqual(dummy1.weight.float8_data, cat_qweight1.float8_data)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+
+        ref_qweight2_float8_data = torch.cat([linear1.weight.float8_data, linear2.weight.float8_data], dim=1)
+        ref_qweight2_scale = torch.cat([linear1.weight.scale, linear2.weight.scale], dim=1)
+        self.assertEqual(cat_qweight2.float8_data, ref_qweight2_float8_data)
+        self.assertEqual(cat_qweight2.scale, ref_qweight2_scale)
+
+    def test_transpose(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(
+            torch.randn(128, dtype=dtype, device=device)
+        )
+        self.assertTrue(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=dtype, device=device)
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 128))
+
 
 if __name__ == "__main__":
     run_tests()

test/dtypes/test_fbgemm_int4.py

Lines changed: 46 additions & 0 deletions
@@ -152,6 +152,52 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)
 
+    def test_cat(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=dtype)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=dtype)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        dummy1 = torch.nn.Linear(128, 512, bias=False, dtype=dtype, device=device)
+        dummy2 = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1.weight = torch.nn.Parameter(cat_weight1)
+        dummy2.weight = torch.nn.Parameter(cat_weight2)
+        quantize_(dummy1, self.config)
+        quantize_(dummy2, self.config)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_qweight1.shape, (512, 128))
+        self.assertTrue(cat_qweight2.shape, (256, 256))
+        self.assertEqual(dummy1.weight.packed_weight, cat_qweight1.packed_weight)
+        self.assertEqual(dummy1.weight.scale, cat_qweight1.scale)
+        self.assertEqual(dummy1.weight.zero_point, cat_qweight1.zero_point)
+        self.assertEqual(dummy2.weight.packed_weight, cat_qweight2.packed_weight)
+        self.assertEqual(dummy2.weight.scale, cat_qweight2.scale)
+        self.assertEqual(dummy2.weight.zero_point, cat_qweight2.zero_point)
+
+    def test_transpose(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        # transpose again to return to the original state
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        self.assertTrue(linear1.weight.shape, (256, 128))
+
+        input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 256))
+
 
 if __name__ == "__main__":
     run_tests()

torchao/dtypes/fbgemm_fp8_tensor.py

Lines changed: 113 additions & 11 deletions
@@ -26,7 +26,16 @@
 
 class FbgemmFp8Tensor(TorchAOBaseTensor):
     """
+    Float8 Rowwise Quantized (weight) Tensor, with float8 rowwise dynamic quantization for activation.
     TODO: needs padding for cutlass kernels
+
+    Tensor Attributes:
+        float8_data: float8 raw data, dtype torchao.float8.config.e4m3_dtype
+        scale: the rowwise scale for float8 Tensor
+        activation_scale_ub: upper bound for activation scale, used during dynamic quantization for activation
+
+    Non-Tensor Attributes:
+        dtype: Original Tensor dtype
     """
 
     tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
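
As a quick illustration (not from the commit) of the attributes documented above, a quantized weight might be inspected like this; `config` is again a stand-in for the FP8 config used in test/dtypes/test_fbgemm_fp8.py.

# Hedged sketch: inspecting FbgemmFp8Tensor attributes after quantization.
# `config` is a stand-in for the fbgemm FP8 config used in the tests.
import torch
from torchao.quantization import quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)

w = linear.weight                 # FbgemmFp8Tensor
print(w.float8_data.dtype)        # float8 raw data (e4m3)
print(w.scale.shape)              # rowwise scales, one per output row
print(w.activation_scale_ub)      # clamp bound used during dynamic activation quantization
print(w.dtype)                    # original high-precision dtype (torch.bfloat16)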
@@ -40,7 +49,9 @@ def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
+    def __init__(
+        self, float8_data, scale, activation_scale_ub, dtype
+    ):
         self.float8_data = float8_data
         self.scale = scale
         self.activation_scale_ub = activation_scale_ub
@@ -85,6 +96,47 @@ def to(self, *args, **kwargs):
             self.dtype,
         )
 
+    def _transpose_and_reshape(self):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        assert len(self.shape) == 3, (
+            f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
+        )
+        dim0, dim1, dim2 = self.shape
+        # because we first transpose the weight before quantization, we'll recover the original shape
+        # by swapping dim1 and dim2
+        original_shape = (dim0, dim2, dim1)
+        # we must save this as 2D in the state dict, since loading code expects 2D weights
+        new_shape = (-1, original_shape[-1])
+        float8_data = self.float8_data
+        float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous()
+        scale = self.scale.transpose(1, 2).reshape(*new_shape).contiguous()
+        return self.__class__(
+            float8_data,
+            scale,
+            self.activation_scale_ub,
+            self.dtype,
+        )
+
+    def _unflatten(self, num_experts):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        float8_data = self.float8_data
+        scale = self.scale
+        dim0, dim1 = self.shape
+        float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        scale = scale.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        dim0, dim1, dim2 = float8_data.shape
+
+        return self.__class__(
+            float8_data,
+            scale,
+            self.activation_scale_ub,
+            self.dtype,
+        )
+
     @classmethod
     def from_float(
         cls,
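
For context, the shape arithmetic that _transpose_and_reshape and _unflatten perform on float8_data and scale looks like the following sketch on plain tensors (illustrative only; the 3D case assumes a stacked per-expert weight of shape (num_experts, N, K)).

# Hedged sketch of the helper shape arithmetic on plain tensors.
import torch

num_experts, N, K = 2, 256, 128
data = torch.zeros(num_experts, N, K)

# _transpose_and_reshape: (E, N, K) -> transpose(1, 2) -> (E, K, N) -> 2D
flat = data.transpose(1, 2).reshape(-1, N).contiguous()        # (E * K, N)

# _unflatten(num_experts): group rows back by expert; squeeze(dim=0) only
# removes the leading dim when num_experts == 1
grouped = flat.unflatten(0, (num_experts, -1)).squeeze(dim=0)  # (E, K, N)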
@@ -106,8 +158,10 @@ def from_float(
         else:
             w = w.t()
 
-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        # add a last dimension for per row quantization to align the rank of
+        # w_scale and wq
+        w_scale = w_scale.unsqueeze(-1).contiguous()
         dtype = w.dtype
         del w
         return FbgemmFp8Tensor(
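
The unsqueeze(-1) above exists so that the rowwise scale has the same rank as the quantized data; a minimal illustration (not from the commit):

# Hedged sketch: rank alignment of the rowwise scale with the quantized data.
import torch

wq = torch.zeros(256, 128)       # stand-in for the float8 weight data
w_scale = torch.ones(256)        # one scale per row
w_scale = w_scale.unsqueeze(-1)  # (256, 1): same rank as wq, broadcasts over columns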
@@ -133,18 +187,18 @@ def _(func, types, args, kwargs):
 
     # not used
     num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
-    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+    a_data, a_scale = torch.ops.fbgemm.quantize_fp8_per_row(
         input_tensor, num_tokens, weight_tensor.activation_scale_ub
     )
 
-    a_data = xq
     b_data = weight_tensor.float8_data
+    b_scale = weight_tensor.scale.squeeze(-1)
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise(
         a_data,
         b_data,
-        x_scale,
-        weight_tensor.scale,
+        a_scale,
+        b_scale,
         use_fast_accum=True,
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)
@@ -163,19 +217,21 @@ def _(func, types, args, kwargs):
     orig_act_size = input_tensor.size()
     # not used
    num_tokens = torch.empty([input_tensor.size(0)], device=input_tensor.device)
-    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+    a_data, a_scale = torch.ops.fbgemm.quantize_fp8_per_row(
         input_tensor, num_tokens, weight_tensor.activation_scale_ub
     )
 
-    a_data = xq
     b_data = weight_tensor.float8_data
+    b_scale = weight_tensor.scale.squeeze(-1)
+    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+
     orig_out_features = b_data.shape[-2]
 
     res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
         a_data,
         b_data,
-        x_scale,
-        weight_tensor.scale,
+        a_scale,
+        b_scale,
     )
     res = res.reshape(*orig_act_size[:-1], orig_out_features)
     return res
@@ -269,6 +325,52 @@ def _(func, types, args, kwargs):
     )
 
 
+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = tensor_0.ndim + dim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub
+
+    float8_datas = [t.float8_data for t in tensors]
+    scales = [t.scale for t in tensors]
+
+    # with rowwise quantization, the dimensions of float8_data and the
+    # original shape will be the same, so the original dim argument applies
+    # to float8_data
+    cat_float8_data = aten.cat.default(float8_datas, dim)
+
+    if dim != 2:
+        cat_scale = aten.cat.default(scales, dim=dim)
+    else:
+        cat_scale = scales[0]
+
+    new = tensor_0.__class__(
+        cat_float8_data,
+        cat_scale,
+        tensor_0.activation_scale_ub,
+        tensor_0.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    float8_data = self.float8_data.transpose(dim0, dim1).contiguous()
+    scale = self.scale.transpose(dim0, dim1).contiguous()
+
+    new = self.__class__(
+        float8_data, scale, self.activation_scale_ub, self.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
 to_fbgemm_fp8 = FbgemmFp8Tensor.from_float
 
 
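At the component level, the aten.cat override above behaves as in this plain-tensor sketch (illustrative only; for dim == 2, which only arises for 3D weights, the first tensor's scale is reused instead of concatenating scales).

# Hedged sketch: what the cat override does to float8_data and scale for dim != 2.
import torch

d1, s1 = torch.zeros(256, 128), torch.ones(256, 1)  # data/scale of tensor 1
d2, s2 = torch.zeros(256, 128), torch.ones(256, 1)  # data/scale of tensor 2

cat_data = torch.cat([d1, d2], dim=0)   # (512, 128): raw data concatenated
cat_scale = torch.cat([s1, s2], dim=0)  # (512, 1): rowwise scales concatenated the same way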
