Skip to content

Commit 3abfe22

Browse files
authored
Enable group size 64 for Machete (#20290)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
1 parent e81fbef commit 3abfe22

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

tests/kernels/quantization/test_machete_mm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from tests.kernels.utils import opcheck
1616
from vllm import _custom_ops as ops
17+
from vllm.model_executor.layers.quantization.utils.machete_utils import (
18+
query_machete_supported_group_sizes)
1719
from vllm.model_executor.layers.quantization.utils.quant_utils import (
1820
pack_rows, quantize_weights)
1921
from vllm.platforms import current_platform
@@ -46,8 +48,6 @@
4648
(1024, 8192, 4096),
4749
]
4850

49-
GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
50-
5151

5252
@dataclass
5353
class TypeConfig:
@@ -270,7 +270,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
270270
if types.group_scale_type is None:
271271
group_sizes = [None]
272272
else:
273-
group_sizes = GROUP_SIZES_TO_TEST
273+
group_sizes = query_machete_supported_group_sizes(types.act_type)
274274

275275
for group_size in group_sizes:
276276
if not group_size_valid(shape, group_size):
@@ -299,7 +299,7 @@ def test_machete_heuristic(shape, types: TypeConfig):
299299
if types.group_scale_type is None:
300300
group_sizes = [None]
301301
else:
302-
group_sizes = GROUP_SIZES_TO_TEST
302+
group_sizes = query_machete_supported_group_sizes(types.act_type)
303303

304304
for group_size in group_sizes:
305305
if not group_size_valid(shape, group_size):

vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from vllm import _custom_ops as ops
1010
from vllm.model_executor.layers.quantization.utils.machete_utils import (
11-
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
11+
check_machete_supports_shape, query_machete_supported_group_sizes,
1212
query_machete_supported_quant_types)
1313
from vllm.model_executor.layers.quantization.utils.quant_utils import (
1414
pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
@@ -40,10 +40,10 @@ def can_implement(cls,
4040
"Machete, supported types are: "\
4141
f"{query_machete_supported_quant_types(c.zero_points)}"
4242

43-
if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
43+
if c.group_size not in query_machete_supported_group_sizes(c.act_type):
4444
return False, f"Group size ({c.group_size}) not supported by "\
4545
"Machete, supported group sizes are: "\
46-
f"{MACHETE_SUPPORTED_GROUP_SIZES}"
46+
f"{query_machete_supported_group_sizes(c.act_type)}"
4747

4848
return check_machete_supports_shape(c.partition_weight_shape[0],
4949
c.partition_weight_shape[1])

vllm/model_executor/layers/quantization/utils/machete_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from vllm.scalar_type import ScalarType, scalar_types
99

10-
MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
1110
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
1211

1312

@@ -22,6 +21,24 @@ def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
2221
return [torch.float16, torch.bfloat16]
2322

2423

24+
def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]:
    """
    Queries the supported group sizes for Machete based on the activation type.

    Args:
        act_type: The activation data type (torch.float16, torch.bfloat16).

    Returns:
        A newly built list of supported group sizes. The group size must
        be divisible by `TileShapeK = 128 * 8 // num_bits(act_type)`; for the
        16-bit activation types handled explicitly here that evaluates to
        64, which is why both 64 and 128 are accepted for them.
        -1 indicates per-channel quantization.
    """
    if act_type in (torch.float16, torch.bfloat16):
        return [-1, 64, 128]
    # Every other activation type falls back to per-channel and group size 128.
    return [-1, 128]
40+
41+
2542
def check_machete_supports_shape(in_features: int, out_featrues: int) \
2643
-> tuple[bool, Optional[str]]:
2744
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:

0 commit comments

Comments
 (0)