3 files changed: +25 −8
tests/kernels/quantization/test_machete_mm.py
@@ -14,6 +14,8 @@
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.machete_utils import (
+    query_machete_supported_group_sizes)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_rows, quantize_weights)
 from vllm.platforms import current_platform
@@ -46,8 +48,6 @@
     (1024, 8192, 4096),
 ]
 
-GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
-
 
 @dataclass
 class TypeConfig:
@@ -270,7 +270,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
     if types.group_scale_type is None:
         group_sizes = [None]
     else:
-        group_sizes = GROUP_SIZES_TO_TEST
+        group_sizes = query_machete_supported_group_sizes(types.act_type)
 
     for group_size in group_sizes:
         if not group_size_valid(shape, group_size):
@@ -299,7 +299,7 @@ def test_machete_heuristic(shape, types: TypeConfig):
     if types.group_scale_type is None:
         group_sizes = [None]
     else:
-        group_sizes = GROUP_SIZES_TO_TEST
+        group_sizes = query_machete_supported_group_sizes(types.act_type)
 
     for group_size in group_sizes:
         if not group_size_valid(shape, group_size):
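With the hard-coded `GROUP_SIZES_TO_TEST` constant gone, the tested group sizes now track the activation dtype. A minimal sketch of what the updated else branch yields, assuming the import shown in the diff above:

```python
import torch

from vllm.model_executor.layers.quantization.utils.machete_utils import (
    query_machete_supported_group_sizes)

# For 16-bit activations the tests now also cover group size 64.
for group_size in query_machete_supported_group_sizes(torch.float16):
    print(group_size)  # -1, 64, 128
```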
vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
@@ -8,7 +8,7 @@
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.machete_utils import (
-    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
+    check_machete_supports_shape, query_machete_supported_group_sizes,
     query_machete_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
@@ -40,10 +40,10 @@ def can_implement(cls,
                       "Machete, supported types are: " \
                       f"{query_machete_supported_quant_types(c.zero_points)}"
 
-        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
+        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
             return False, f"Group size ({c.group_size}) not supported by " \
                           "Machete, supported group sizes are: " \
-                          f"{MACHETE_SUPPORTED_GROUP_SIZES}"
+                          f"{query_machete_supported_group_sizes(c.act_type)}"
 
         return check_machete_supports_shape(c.partition_weight_shape[0],
                                             c.partition_weight_shape[1])
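`can_implement` now validates the group size against the activation dtype instead of a single global list. A hedged sketch of the new membership test; `group_size_supported` is a hypothetical helper, not part of the kernel:

```python
import torch

from vllm.model_executor.layers.quantization.utils.machete_utils import (
    query_machete_supported_group_sizes)

def group_size_supported(group_size: int, act_type: torch.dtype) -> bool:
    # Mirrors the check in can_implement() above.
    return group_size in query_machete_supported_group_sizes(act_type)

print(group_size_supported(64, torch.bfloat16))       # True with this change
print(group_size_supported(64, torch.float8_e4m3fn))  # False: only -1 and 128
```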
vllm/model_executor/layers/quantization/utils/machete_utils.py
@@ -7,7 +7,6 @@
 
 from vllm.scalar_type import ScalarType, scalar_types
 
-MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
 MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
 
 
@@ -22,6 +21,24 @@ def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
     return [torch.float16, torch.bfloat16]
 
 
+def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]:
+    """
+    Queries the supported group sizes for Machete based on the activation type.
+
+    Args:
+        act_type: The activation data type (torch.float16, torch.bfloat16).
+
+    Returns:
+        A list of supported group sizes. The group size must
+        be divisible by `TileShapeK = 128 * 8 // num_bits(act_type)`.
+        -1 indicates per-channel quantization.
+    """
+    if act_type in [torch.float16, torch.bfloat16]:
+        return [-1, 64, 128]
+    else:
+        return [-1, 128]
+
+
 def check_machete_supports_shape(in_features: int, out_featrues: int) \
         -> tuple[bool, Optional[str]]:
     if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
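The docstring's divisibility rule explains the two branches: for 16-bit activations TileShapeK = 128 * 8 // 16 = 64, so both 64 and 128 are valid multiples, while for 8-bit activations TileShapeK = 128 * 8 // 8 = 128, leaving 128 as the only grouped option. A small sketch of that arithmetic; `tile_shape_k` is illustrative, not part of machete_utils:

```python
import torch

def tile_shape_k(act_type: torch.dtype) -> int:
    # TileShapeK = 128 * 8 // num_bits(act_type), per the docstring above.
    return 128 * 8 // torch.finfo(act_type).bits

assert tile_shape_k(torch.float16) == 64         # 64 and 128 both divisible
assert tile_shape_k(torch.bfloat16) == 64
assert tile_shape_k(torch.float8_e4m3fn) == 128  # only 128 divisible
```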