Commit 8e33b70

Update GemLite to support vLLM V1 (#2199)
* update to forward_functional()
* add 8-bit symmetric case
* ruff
* fix test
1 parent 04fb450 commit 8e33b70

File tree

3 files changed: +15 / -66 lines


test/quantization/test_config_serialization.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -63,8 +63,6 @@
     GemliteUIntXWeightOnlyConfig(
         group_size=128,  # Optional, has default of 64
         bit_width=8,  # Optional, has default of 4
-        packing_bitwidth=8,  # Optional, has default of 32
-        contiguous=True,  # Optional, has default of None
     ),
     FPXWeightOnlyConfig(ebits=4, mbits=8),
     # Sparsity configs
```
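Since this test is the only place in the commit that exercises the config surface directly, here is a minimal sketch of what constructing the slimmed-down config now looks like. The import path and the TypeError note are assumptions based on the field declarations in quant_api.py further below, not part of the commit itself.

```python
# Sketch: the config now exposes only group_size and bit_width (plus
# set_inductor_config). Passing the removed packing_bitwidth/contiguous
# kwargs should fail at construction time if the config is a dataclass,
# as its field-style declarations suggest.
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

config = GemliteUIntXWeightOnlyConfig(
    group_size=128,  # optional, defaults to 64
    bit_width=8,     # optional, defaults to 4
)

# GemliteUIntXWeightOnlyConfig(packing_bitwidth=8)  # would now raise TypeError
```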

torchao/dtypes/uintx/gemlite_layout.py

Lines changed: 14 additions & 57 deletions
```diff
@@ -22,7 +22,6 @@
 
 try:
     import gemlite
-    from gemlite.core import GemLiteLinearTriton
 except:
     gemlite = None
 
@@ -51,18 +50,6 @@ def _same_metadata(
     )
 
 
-def scale_activations_no_scaling(x):
-    return x, None
-
-
-def scale_activations_int8(x):
-    x_shape = x.shape
-    out_x = x.view(-1, x.shape[-1])
-    scaled_x = torch.abs(out_x).amax(axis=1, keepdim=True) / 127
-    out_x = torch.round(out_x / scaled_x).to(dtype=torch.int8)
-    return out_x.view(x_shape), scaled_x
-
-
 def get_gemlite_quant_kwargs(bit_width, group_size, dtype):
     from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 
@@ -93,8 +80,6 @@ def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
-    packing_bitwidth=32,
-    contiguous=None,
     use_hqq=True,
 ):
     if gemlite is None:
@@ -106,12 +91,7 @@ def get_gemlite_aqt_kwargs(
         4,
         8,
     ], f"gemlite only works with bit_width 4,8 but got {bit_width}"
-    assert packing_bitwidth in [
-        8,
-        16,
-        32,
-        None,
-    ], f"gemlite needs packing_bitwidth in [8, 16, 32] but got {packing_bitwidth}"
+
     assert weight.dtype in [torch.float16, torch.bfloat16], (
         f"gemlite only works with dtype torch.float16 or torch.bfloat16 but got {weight.dtype}"
     )
@@ -127,8 +107,6 @@ def get_gemlite_aqt_kwargs(
     aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
-        packing_bitwidth=packing_bitwidth,
-        contiguous=contiguous,
     )
     aqt_kwargs["use_hqq"] = use_hqq
     return aqt_kwargs
@@ -138,8 +116,6 @@ def get_gemlite_aqt_kwargs(
 class GemlitePackedLayout(Layout):
     group_size: Optional[int] = 64
     bit_width: int = 4
-    packing_bitwidth: int = None
-    contiguous: bool = None
 
 
 @register_layout(GemlitePackedLayout)
@@ -216,13 +192,18 @@ def from_plain(
         group_size, bit_width = _layout.group_size, _layout.bit_width
         out_features, in_features = int_data.shape
 
-        gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
-            int_data, scale, zero_point, bit_width, group_size, bias=None
-        )
+        if bit_width == 8 and group_size == in_features:
+            gemlite_linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
+                int_data, scales=scale, bias=None
+            )
+        else:
+            gemlite_linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
+                int_data, scale, zero_point, bit_width, group_size, bias=None
+            )
 
         gemlite_kwargs = {
+            "in_features": in_features,
            "out_features": out_features,
-            "scaled_activations": gemlite_linear.scaled_activations,
            "meta_args": gemlite_linear.get_meta_args(),
        }
 
@@ -253,20 +234,17 @@ def _apply_fn_to_data(self, fn):
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         device = self.packed_weight.device
-        elements_per_sample = self._layout.packing_bitwidth // self._layout.bit_width
-        in_features = (
-            self.packed_weight.numel() * elements_per_sample
-        ) // self.gemlite_kwargs["out_features"]
         int_data = (
             gemlite.bitpack.unpack_over_rows(
                 self.packed_weight.cuda(),
                 W_nbits=self._layout.bit_width,
-                num_output_rows=in_features,
+                num_output_rows=self.gemlite_kwargs["out_features"],
                 dtype=torch.uint8,
             )
             .t()
             .contiguous()
         ).to(device)
+
         scale = self.scale.t().contiguous()
         zero_point = self.zero_point.t().contiguous()
 
@@ -353,42 +331,21 @@ def block_size(self):
         return (1, self._layout.group_size)
 
 
-# logic taken from gemlite's core.py
-def _matmul_type_fn(batch_size: int, bit_width: int) -> str:
-    if batch_size > 64:
-        return "GEMM"
-    elif batch_size > 1:
-        return "GEMM_SPLITK"
-    else:
-        return gemlite.core.get_default_gemv(bit_width)
-
-
 def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None):
     if hasattr(weight_tensor, "tensor_impl"):
         weight_impl = weight_tensor.tensor_impl
     else:
         weight_impl = weight_tensor
 
-    batch_size = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]
-    matmul_type = _matmul_type_fn(batch_size, weight_impl._layout.bit_width)
-
-    if weight_impl.gemlite_kwargs["scaled_activations"]:
-        scale_activations = scale_activations_int8
-    else:
-        scale_activations = scale_activations_no_scaling
-
-    return GemLiteLinearTriton.forward_functional(
+    return gemlite.core.forward_functional(
         x=input_tensor,
         bias=bias,
-        matmul_type=matmul_type,
-        out_features=weight_impl.gemlite_kwargs["out_features"],
-        scale_activations=scale_activations,
-        meta_args=weight_impl.gemlite_kwargs["meta_args"],
         tensor_args=(
             weight_impl.packed_weight,
             weight_impl.scale,
             weight_impl.zero_point,
         ),
+        meta_args=weight_impl.gemlite_kwargs["meta_args"],
     )
 
 
```
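To make the new packing dispatch easier to follow outside the diff context, here is a condensed sketch of what from_plain now does. The wrapper name pack_for_gemlite is invented for illustration; the gemlite.helper calls and the gemlite_kwargs keys are taken directly from the hunks above.

```python
# Condensed sketch of the new packing dispatch in from_plain; the function
# name is hypothetical, the GemLite calls mirror the diff above.
import gemlite


def pack_for_gemlite(int_data, scale, zero_point, bit_width, group_size):
    out_features, in_features = int_data.shape
    if bit_width == 8 and group_size == in_features:
        # New 8-bit symmetric, per-channel path: no zero_point needed.
        linear = gemlite.helper.A16W8(device=int_data.device).from_weights(
            int_data, scales=scale, bias=None
        )
    else:
        # Existing grouped/asymmetric path.
        linear = gemlite.helper.A16Wn(device=int_data.device).from_weights(
            int_data, scale, zero_point, bit_width, group_size, bias=None
        )
    # The tensor impl now only needs to carry the shapes and GemLite's
    # meta_args; matmul-type selection and activation scaling are no longer
    # tracked on the torchao side.
    return linear, {
        "in_features": in_features,
        "out_features": out_features,
        "meta_args": linear.get_meta_args(),
    }
```

On the forward side, the impl then simply calls gemlite.core.forward_functional(x=..., bias=..., tensor_args=(packed_weight, scale, zero_point), meta_args=meta_args), which is why the hand-rolled _matmul_type_fn and the scale_activations_* helpers could be deleted: those decisions are now resolved inside GemLite itself.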
torchao/quantization/quant_api.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -979,8 +979,6 @@ class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
 
     group_size: Optional[int] = 64
     bit_width: int = 4
-    packing_bitwidth: int = 32
-    contiguous: Optional[bool] = None
     set_inductor_config: bool = True
 
 
@@ -994,8 +992,6 @@ def _gemlite_uintx_weight_only_transform(
 ):
     group_size = config.group_size
     bit_width = config.bit_width
-    packing_bitwidth = config.packing_bitwidth
-    contiguous = config.contiguous
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -1006,9 +1002,7 @@ def _gemlite_uintx_weight_only_transform(
     use_hqq = True if bit_width == 4 else False
     new_weight = to_affine_quantized_intx(
         weight,
-        **get_gemlite_aqt_kwargs(
-            weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
-        ),
+        **get_gemlite_aqt_kwargs(weight, group_size, bit_width, use_hqq),
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
```
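For completeness, a rough end-to-end usage sketch under the trimmed-down API surface. This is not from the commit: quantize_ is torchao's usual entry point, the model and shapes are placeholders, and running it requires a CUDA device plus a gemlite install.

```python
# Illustrative only: quantize a toy fp16 linear layer with the simplified
# config; packing_bitwidth / contiguous are no longer accepted anywhere.
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import GemliteUIntXWeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()
quantize_(model, GemliteUIntXWeightOnlyConfig(bit_width=4, group_size=64))

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = model(x)  # dispatches through gemlite.core.forward_functional under the hood
```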
