
Commit 4235837

Add slicing support for fbgemm fp8 and int4 (#2308)
Summary:
As titled; this is needed in vLLM. Note that irregular shapes will require padding, which is not implemented right now; we can add that if it is required by the model.

Test Plan:
python test/dtypes/test_fbgemm_fp8.py -k test_slice
python test/dtypes/test_fbgemm_int4.py -k test_slice

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 95151b4 commit 4235837
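For context, roughly how the new slicing path gets exercised by a tensor-parallel weight-loading flow like vLLM's. This is a minimal sketch: the import path and the `FbgemmConfig` arguments are assumed to match the tests in this commit, and the shard size is illustrative.

import torch
from torchao.quantization import FbgemmConfig, quantize_

# quantize a bf16 linear to fbgemm fp8 (per-row scales)
linear = torch.nn.Linear(256, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, FbgemmConfig(
    input_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    output_dtype=torch.bfloat16,
))

# narrow() lowers to aten.slice.Tensor, which this commit implements for the
# quantized tensor subclass, so a sharded loader can take a row shard without
# dequantizing; the shard aliases the original float8_data and scale
shard = linear.weight.narrow(0, 0, 128)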

File tree

4 files changed: +344 -26 lines changed


test/dtypes/test_fbgemm_fp8.py

Lines changed: 71 additions & 8 deletions
@@ -25,24 +25,87 @@


 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmFp8Tensor(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    def setUp(self):
+        self.config = FbgemmConfig(
+            input_dtype=e4m3_dtype,
+            weight_dtype=e4m3_dtype,
+            output_dtype=torch.bfloat16,
+        )
+
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        config = FbgemmConfig(
-            input_dtype=e4m3_dtype,
-            weight_dtype=e4m3_dtype,
-            output_dtype=torch.bfloat16,
-        )
-        quantize_(linear, config)
+        quantize_(linear, self.config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)

+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+        self.assertEqual(weight1.float8_data, dummy.weight.float8_data.narrow(0, 0, 64))
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
+        self.assertEqual(
+            weight2.float8_data, dummy.weight.float8_data.narrow(1, 0, 128)
+        )
+        self.assertEqual(weight2.scale, dummy.weight.scale)
+
+        # check that the sliced weight, before and after float8 quantization,
+        # does not differ too much
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 25
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert param.data.float8_data.data_ptr() == param_data.float8_data.data_ptr()
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        orig_value = param.data.float8_data[0][0].item()
+
+        # dummy_l has random weights (they shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.float8_data[0][0] != orig_value
+

 if __name__ == "__main__":
     run_tests()
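The dim-0 vs. dim-1 asymmetry this fp8 test asserts comes from per-row quantization: for an (N, K) weight, `float8_data` is (N, K) while `scale` is (N,). A small plain-tensor sketch of the same bookkeeping, independent of the tensor subclass, with shapes chosen to match the test:

import torch

N, K = 256, 256
float8_data = torch.randint(0, 255, (N, K), dtype=torch.uint8)  # stand-in for fp8 bytes
scale = torch.rand(N)  # one scale per row

# slicing rows (dim 0) keeps only those rows' scales -> slice data and scale together
rows_data, rows_scale = float8_data.narrow(0, 0, 64), scale.narrow(0, 0, 64)

# slicing columns (dim 1) leaves every row's scale untouched -> scale is returned as-is
cols_data, cols_scale = float8_data.narrow(1, 0, 128), scale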

test/dtypes/test_fbgemm_int4.py

Lines changed: 77 additions & 9 deletions
@@ -24,25 +24,93 @@


 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
 class TestFbgemmInt4Tensor(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    def setUp(self):
+        self.config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, 128],
+        )
+
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-        )
-        quantize_(linear, config)
+        quantize_(linear, self.config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)

+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cuda"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+        self.assertEqual(
+            weight1.packed_weight, dummy.weight.packed_weight.narrow(0, 0, 64)
+        )
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
+        self.assertEqual(
+            weight2.packed_weight, dummy.weight.packed_weight.narrow(1, 0, 64)
+        )
+        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
+
+        # check that the sliced weight, before and after int4 quantization,
+        # does not differ too much
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 20
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert (
+            param.data.packed_weight.data_ptr() == param_data.packed_weight.data_ptr()
+        )
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
+        orig_value = param.data.packed_weight[0][0].item()
+
+        # dummy_l has random weights (they shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.packed_weight[0][0] != orig_value
+

 if __name__ == "__main__":
     run_tests()
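The expected indices in the int4 `test_slice` follow from the packed layout: two int4 values are stored per byte along dim 1, and with `block_size=[1, 128]` the groupwise scale has one entry per 128 columns, laid out with the group dimension first (as the assertions imply). A hypothetical helper showing that arithmetic for a dim-1 slice:

def int4_slice_extents(start, length, group_size=128):
    """Map a dim-1 slice of `length` unpacked columns to packed-weight and scale extents."""
    packed = (start // 2, length // 2)  # two int4 values per packed byte
    groups = (start // group_size, max(length // group_size, 1))  # groupwise scale
    return packed, groups

# a 128-column slice touches 64 packed bytes and 1 scale group, matching
# packed_weight.narrow(1, 0, 64) and scale.narrow(0, 0, 1) in the test above
print(int4_slice_extents(0, 128))  # ((0, 64), (0, 1))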

torchao/dtypes/fbgemm_fp8_tensor.py

Lines changed: 86 additions & 3 deletions
@@ -13,6 +13,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
+    fill_defaults,
 )

 __all__ = [
@@ -23,6 +24,10 @@


 class FbgemmFp8Tensor(TorchAOBaseTensor):
+    """
+    TODO: needs padding for cutlass kernels
+    """
+
     tensor_data_attrs = ["float8_data", "scale", "activation_scale_ub"]
     tensor_attributes = ["dtype"]

@@ -118,9 +123,13 @@ def _(func, types, args, kwargs):
     xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
         input_tensor, num_tokens, weight_tensor.activation_scale_ub
     )
+
+    a_data = xq
+    b_data = weight_tensor.float8_data
+
     res = torch.ops.fbgemm.f8f8bf16_rowwise(
-        xq,
-        weight_tensor.float8_data,
+        a_data,
+        b_data,
         x_scale,
         weight_tensor.scale,
         use_fast_accum=True,
@@ -139,13 +148,87 @@ def _(func, types, args, kwargs):
     )


-@implements([aten.clone.default, aten.copy_.default])
+@implements(aten.clone.default)
 def _(func, types, args, kwargs):
     return return_and_correct_aliasing(
         func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
     )


+def _same_metadata(self: "FbgemmFp8Tensor", src: "FbgemmFp8Tensor") -> bool:
+    return (
+        isinstance(self, FbgemmFp8Tensor)
+        and isinstance(src, FbgemmFp8Tensor)
+        and self.shape == src.shape
+        and self.float8_data.shape == src.float8_data.shape
+        and self.scale.shape == src.scale.shape
+        and self.activation_scale_ub.shape == src.activation_scale_ub.shape
+        and self.dtype == src.dtype
+    )
+
+
+@implements(aten.copy_.default)
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    """Only supports slicing for dim == 0 and dim == 1
+    the original tensor shape has dimension (N, K)
+    float8_data has dimension (N, K)
+    scale (per row quantization) has dimension: (N,)
+
+    since float8_data has the same dimensions as the original tensor, we can slice it directly;
+    for scale, we slice when dim is 0 and don't need to do anything for dim 1
+
+    Note that we need to call slice on float8_data and scale directly, because slice
+    is an operation that needs to preserve aliasing, see `test_slice_and_copy_` in `test_fbgemm_fp8`
+    for an example
+    """
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    assert step == 1
+    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
+
+    assert self.float8_data.ndim == 2, (
+        f"Expected float8_data to have 2 dimensions, got {self.float8_data.ndim}"
+    )
+
+    # Always slice the float8_data
+    sliced_data = aten.slice.Tensor(
+        self.float8_data, dim, start, end, step
+    ).contiguous()
+
+    if dim == 0:
+        # scale has dimension (N,) where N is dim 0 of `self`,
+        # so we do the same slice on scale for dimension 0
+        sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step)
+    else:
+        # since scale is per row, slicing along dim == 1 does
+        # not change the scale
+        sliced_scale = self.scale
+
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        FbgemmFp8Tensor(
+            sliced_data, sliced_scale, self.activation_scale_ub, dtype=self.dtype
+        ),
+    )
+
+
 to_fbgemm_fp8 = FbgemmFp8Tensor.from_float

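`fill_defaults`, newly imported from `torchao.utils`, pads the positional args that `__torch_dispatch__` hands to the `aten.slice.Tensor` overload, since callers may omit defaulted trailing arguments. A rough sketch of the behavior assumed from its call site above (not the actual `torchao.utils` implementation):

def fill_defaults_sketch(args, n, defaults_tail):
    # pad `args` out to length n, pulling missing values from the end of defaults_tail
    padded = list(args)
    for i in range(len(args), n):
        padded.append(defaults_tail[i - (n - len(defaults_tail))])
    return padded

# e.g. a dispatch that passed only (tensor, dim, start, end):
# tensor, dim, start, end, step = fill_defaults_sketch((t, 0, 0, 64), 5, [0, None, None, 1])
# -> step is filled in as 1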
