
Commit 9e128c1

Add support for resharding for fbgemm configs
Summary:
Added transpose and cat op support, and also some custom transpose/reshape/unflatten support for resharding. In the future we should probably provide examples for using distributed checkpoint for resharding.

Test Plan:
python test/dtypes/test_fbgemm_int4.py -k test_transpose
python test/dtypes/test_fbgemm_int4.py -k test_cat
python test/dtypes/test_fbgemm_fp8.py -k test_transpose
python test/dtypes/test_fbgemm_fp8.py -k test_cat

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 6243040 commit 9e128c1

File tree

4 files changed: +353 additions, -16 deletions


test/dtypes/test_fbgemm_fp8.py

Lines changed: 29 additions & 0 deletions
@@ -146,6 +146,35 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)
 
+    def test_cat(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_weight1.shape, (512, 128))
+        self.assertTrue(cat_weight2.shape, (256, 256))
+
+    def test_transpose(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        linear1.bias = torch.nn.Parameter(
+            torch.randn(128, dtype=torch.bfloat16, device="cuda")
+        )
+        self.assertTrue(linear1.weight.shape, (128, 256))
+
+        input = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 128))
+
 
 if __name__ == "__main__":
     run_tests()

test/dtypes/test_fbgemm_int4.py

Lines changed: 28 additions & 0 deletions
@@ -152,6 +152,34 @@ def test_to_device(self):
             quantize_(linear, self.config)
             linear.to(device)
 
+    def test_cat(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        # weight: (256, 128)
+        linear2 = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+
+        quantize_(linear1, self.config)
+        quantize_(linear2, self.config)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        cat_weight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertTrue(cat_weight1.shape, (512, 128))
+        self.assertTrue(cat_weight2.shape, (256, 256))
+
+    def test_transpose(self):
+        # weight: (256, 128)
+        linear1 = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        quantize_(linear1, self.config)
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        # transpose again to return to the original state
+        linear1.weight = torch.nn.Parameter(linear1.weight.transpose(0, 1).contiguous())
+        self.assertTrue(linear1.weight.shape, (256, 128))
+
+        input = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
+        # make sure it runs
+        res = linear1(input)
+        self.assertTrue(res.shape, (32, 256))
+
 
 if __name__ == "__main__":
     run_tests()

torchao/dtypes/fbgemm_fp8_tensor.py

Lines changed: 125 additions & 7 deletions
@@ -27,22 +27,30 @@
 class FbgemmFp8Tensor(TorchAOBaseTensor):
     """
     TODO: needs padding for cutlass kernels
+    Args:
+      data_to_scale_dim: the dim mapping from float8_data to scale, e.g.
+        float8_data: (batch_size, output_channel, input_channel)
+        scale: (batch_size, output_channel) (since it's per row quantization)
+        data_to_scale_dim: {0: 0, 1: 1}
     """
 
     tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
-    tensor_attributes = ["dtype"]
+    tensor_attributes = ["data_to_scale_dim", "dtype"]
 
-    def __new__(cls, float8_data, scale, activation_scale_ub, dtype):
+    def __new__(cls, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype):
         shape = float8_data.shape
         kwargs = {}
         kwargs["device"] = float8_data.device
         kwargs["dtype"] = dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, float8_data, scale, activation_scale_ub, dtype):
+    def __init__(
+        self, float8_data, scale, activation_scale_ub, data_to_scale_dim, dtype
+    ):
         self.float8_data = float8_data
         self.scale = scale
+        self.data_to_scale_dim = data_to_scale_dim
         self.activation_scale_ub = activation_scale_ub
 
     def __tensor_flatten__(self):
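To make the new data_to_scale_dim convention concrete, here is a small hand-written illustration of the mapping described in the docstring above. The shapes are hypothetical, and the dict is built by hand here; in practice it is produced by from_float rather than constructed manually:

import torch

# Hypothetical batched weight: per-row FP8 quantization keeps one scale per
# (batch, output_channel) pair, so dims 0 and 1 of the data map to dims 0 and 1
# of the scale, while the input-channel dim (2) has no corresponding scale dim.
float8_data = torch.zeros(8, 256, 128, dtype=torch.float8_e4m3fn)  # (batch, out, in)
scale = torch.ones(8, 256, dtype=torch.float32)                    # (batch, out)
data_to_scale_dim = {0: 0, 1: 1}

# Each entry maps a float8_data dim index to the scale dim that tracks it.
for data_dim, scale_dim in data_to_scale_dim.items():
    assert float8_data.shape[data_dim] == scale.shape[scale_dim]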
@@ -68,12 +76,12 @@ def _apply_fn_to_data(self, fn):
     def __repr__(self):
         return (
             f"{self.__class__.__name__}(weight={self.float8_data}, scale={self.scale}, "
-            f"activation_scale_ub={self.activation_scale_ub}, "
+            f"activation_scale_ub={self.activation_scale_ub}, data_to_scale_dim={self.data_to_scale_dim}, "
             f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def _quantization_type(self):
-        return f"shape={self.shape}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
+        return f"shape={self.shape}, data_to_scale_dim={self.data_to_scale_dim}, activation_scale_ub={self.activation_scale_ub}, device={self.device}"
 
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -82,6 +90,53 @@ def to(self, *args, **kwargs):
             self.float8_data.to(device),
             self.scale.to(device),
             self.activation_scale_ub.to(device),
+            self.data_to_scale_dim,
+            self.dtype,
+        )
+
+    def _transpose_and_reshape(self):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        assert len(self.shape) == 3, (
+            f"Only expected to be used when the Tensor is 3D, got {len(self.shape)}"
+        )
+        dim0, dim1, dim2 = self.shape
+        # because we first transpose the weight before quantization, we'll recover the original shape
+        # by swapping dim1 and dim2
+        original_shape = (dim0, dim2, dim1)
+        # we must save this as 2D in the state dict, since loading code expects 2D weights
+        new_shape = (-1, original_shape[-1])
+        float8_data = self.float8_data
+        float8_data = float8_data.transpose(1, 2).reshape(*new_shape).contiguous()
+        data_to_scale_dim = {0: 0, 1: 1}
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
+            self.dtype,
+        )
+
+    def _unflatten(self, num_experts):
+        """This is added for resharding support, since the resharding logic for the model we are
+        working with only supports 2D
+        """
+        float8_data = self.float8_data
+        dim0, dim1 = self.shape
+        float8_data = float8_data.unflatten(0, (num_experts, -1)).squeeze(dim=0)
+        data_to_scale_dim = {0: 0}
+        dim0, dim1, dim2 = float8_data.shape
+        if dim1 == self.scale.shape[1]:
+            data_to_scale_dim[1] = 1
+        else:
+            data_to_scale_dim[2] = 1
+
+        return self.__class__(
+            float8_data,
+            self.scale,
+            self.activation_scale_ub,
+            data_to_scale_dim,
             self.dtype,
         )
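For orientation, here is a rough sketch of how these private helpers are meant to be combined around resharding logic that only understands 2D weights. The tensor is hand-constructed with hypothetical shapes; in the real flow it would come out of quantize_ with an fbgemm FP8 config, and the actual reshard/save/load steps are omitted:

import torch
from torchao.dtypes.fbgemm_fp8_tensor import FbgemmFp8Tensor

# Hand-built stand-in for a quantized 3D expert weight (hypothetical shapes).
num_experts, out_features, in_features = 4, 256, 128
float8_data = torch.zeros(num_experts, out_features, in_features, dtype=torch.float8_e4m3fn)
scale = torch.ones(num_experts, out_features, dtype=torch.float32)
ub = torch.tensor([1200.0], dtype=torch.float32)
w = FbgemmFp8Tensor(float8_data, scale, ub, {0: 0, 1: 1}, torch.bfloat16)

# Collapse the 3D weight into the 2D layout that the resharding and
# state-dict loading code expects.
w_2d = w._transpose_and_reshape()

# ... reshard / save / load the 2D representation here ...

# Rebuild a batched, per-expert layout from the flattened weight on the other side.
w_3d = w_2d._unflatten(num_experts)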

@@ -106,14 +161,18 @@ def from_float(
         else:
             w = w.t()
 
-        wq, w_scale = torch.ops.triton.quantize_fp8_row(w)
-        # wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+        data_to_scale_dim = {0: 0}
+        if w.ndim == 3:
+            data_to_scale_dim[1] = 1
+
+        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
         dtype = w.dtype
         del w
         return FbgemmFp8Tensor(
             wq,
             w_scale,
             activation_scale_ub=activation_scale_ub,
+            data_to_scale_dim=data_to_scale_dim,
             dtype=dtype,
         )

@@ -169,6 +228,8 @@ def _(func, types, args, kwargs):
 
     a_data = xq
     b_data = weight_tensor.float8_data
+    assert b_data.is_contiguous(), "weight for bmm must be contiguous"
+
    orig_out_features = b_data.shape[-2]
 
    res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
@@ -269,6 +330,63 @@ def _(func, types, args, kwargs):
     )
 
 
+@implements(aten.cat.default)
+def _(func, types, args, kwargs):
+    tensors, dim = fill_defaults(args, 2, [[], 0])
+    tensor_0 = tensors[0]
+    if dim < 0:
+        dim = tensor_0.ndim + dim
+
+    for i in range(1, len(tensors)):
+        assert tensor_0.float8_data.ndim == tensors[i].float8_data.ndim
+        assert tensor_0.scale.ndim == tensors[i].scale.ndim
+        assert tensor_0.activation_scale_ub == tensors[i].activation_scale_ub
+        assert tensor_0.data_to_scale_dim == tensors[i].data_to_scale_dim
+
+    float8_data = [t.float8_data for t in tensors]
+    scale = [t.scale for t in tensors]
+
+    # with rowwise quantization, dimension of float8_data and
+    # original shape will be the same, so original dim argument applies
+    # to float8_data
+    cat_float8_data = aten.cat.default(float8_data, dim)
+
+    # if cat dimension has a corresponding scale dimension, then we'll concat the corresponding
+    # scale dimension, otherwise, we'll just use the existing scale
+    if dim in tensor_0.data_to_scale_dim:
+        cat_scale = aten.cat.default(scale, dim=tensor_0.data_to_scale_dim[dim])
+    else:
+        cat_scale = scale[0]
+
+    new = tensor_0.__class__(
+        cat_float8_data,
+        cat_scale,
+        tensor_0.activation_scale_ub,
+        tensor_0.data_to_scale_dim,
+        tensor_0.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
+@implements(aten.transpose.int)
+def _(func, types, args, kwargs):
+    self, dim0, dim1 = args
+    float8_data = self.float8_data.transpose(dim0, dim1).contiguous()
+    data_to_scale_dim = self.data_to_scale_dim.copy()
+
+    if dim0 in data_to_scale_dim:
+        data_to_scale_dim[dim1] = data_to_scale_dim[dim0]
+        del data_to_scale_dim[dim0]
+    elif dim1 in data_to_scale_dim:
+        data_to_scale_dim[dim0] = data_to_scale_dim[dim1]
+        del data_to_scale_dim[dim1]
+
+    new = self.__class__(
+        float8_data, self.scale, self.activation_scale_ub, data_to_scale_dim, self.dtype
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
 to_fbgemm_fp8 = FbgemmFp8Tensor.from_float
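As a rough illustration of what the new aten.cat and aten.transpose.int overrides do with data_to_scale_dim (again using a hand-constructed tensor with hypothetical shapes rather than one produced by quantize_):

import torch
from torchao.dtypes.fbgemm_fp8_tensor import FbgemmFp8Tensor

# Hypothetical 3D weight whose scales track dims 0 and 1 of the data.
float8_data = torch.zeros(4, 256, 128, dtype=torch.float8_e4m3fn)
scale = torch.ones(4, 256, dtype=torch.float32)
ub = torch.tensor([1200.0], dtype=torch.float32)
w = FbgemmFp8Tensor(float8_data, scale, ub, {0: 0, 1: 1}, torch.bfloat16)

# transpose swaps the data dims and remaps the dict entries accordingly.
wt = w.transpose(1, 2)
assert wt.shape == (4, 128, 256)
assert wt.data_to_scale_dim == {0: 0, 2: 1}

# cat along a dim that has a scale dim concatenates the scales as well;
# cat along a dim without one would reuse the existing scale.
w2 = torch.cat([w, w], dim=1)
assert w2.shape == (4, 512, 128)
assert w2.scale.shape == (4, 512)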
