
Commit bf03ff3

[Kernel] Add Conch backend for mixed-precision linear layer (#19818)
Signed-off-by: Jacob Manning <jmanning+oss@stackav.com>
1 parent 47043eb commit bf03ff3

5 files changed: 105 additions & 1 deletion


requirements/rocm.txt

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
+conch-triton-kernels==1.2.1
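The pin above makes the Triton-based Conch kernels installable on ROCm builds; at runtime the backend only advertises itself when the package actually resolves, using the same check that appears in the new conch.py below. A minimal sketch of that probe:

from importlib.util import find_spec

# ConchLinearKernel.can_implement() (added in this commit) declines with an
# install hint when this resolves to None.
conch_available = find_spec("conch") is not None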

vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -8,6 +8,8 @@
     AllSparkLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import (  # noqa: E501
     BitBLASLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import (  # noqa: E501
+    ConchLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
@@ -24,6 +26,7 @@
     AllSparkLinearKernel,
     MarlinLinearKernel,
     BitBLASLinearKernel,
+    ConchLinearKernel,
     ExllamaLinearKernel,
 ]

@@ -80,4 +83,4 @@ def choose_mp_linear_kernel(
     raise ValueError(
         "Failed to find a kernel that can implement the "\
         "WNA16 linear layer. Reasons: \n"
-        + '\n'.join(failure_reasons))
+        + '\n'.join(failure_reasons))
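For context, the list edited above is the ordered candidate list that choose_mp_linear_kernel walks to pick an implementation. Below is a minimal, self-contained sketch of that selection pattern, using only the get_min_capability/can_implement classmethods touched by this commit; the function name pick_kernel and the exact failure-reason formatting are illustrative, not the vLLM implementation.

def pick_kernel(candidates, config, compute_capability: int):
    """Return the first kernel class whose minimum compute capability is met
    and whose can_implement(config) returns (True, None); otherwise raise with
    every collected failure reason, mirroring the error in the hunk above."""
    failure_reasons: list[str] = []
    for kernel_cls in candidates:  # ordered by preference
        min_cap = kernel_cls.get_min_capability()
        if min_cap > compute_capability:
            failure_reasons.append(
                f"{kernel_cls.__name__}: requires capability >= {min_cap}")
            continue
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("Failed to find a kernel that can implement the "
                     "WNA16 linear layer. Reasons: \n" +
                     '\n'.join(failure_reasons))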
vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from importlib.util import find_spec
+from typing import Final, Optional
+
+import torch
+
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
+    scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
+    scalar_types.uint8b128
+]
+_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
+
+
+class ConchLinearKernel(MPLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
+            error_msg = f"Weight type ({c.weight_type}) not supported by "\
+                "ConchLinearKernel, supported types are: " \
+                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
+            return False, error_msg
+
+        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
+            error_msg = f"Group size ({c.group_size}) not supported by "\
+                "ConchLinearKernel, supported group sizes are: " \
+                f"{_CONCH_SUPPORTED_GROUP_SIZES}"
+            return False, error_msg
+
+        if find_spec("conch") is None:
+            error_msg = "conch-triton-kernels is not installed, please "\
+                "install it via `pip install conch-triton-kernels` "\
+                "and try again!"
+            return False, error_msg
+
+        return True, None
+
+    # note assumes that
+    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
+    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = x.data.contiguous()
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x
+
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        from conch.ops.quantization.gemm import mixed_precision_gemm
+
+        w_q, w_s, w_zp, _ = self._get_weight_params(layer)
+
+        output = mixed_precision_gemm(
+            x=x,
+            w_q_packed=w_q.data,
+            w_s=w_s.data,
+            w_zp=w_zp.data if w_zp is not None else None,
+            weight_size_bits=self.config.weight_type.size_bits,
+            weight_bias=self.config.weight_type.bias,
+            group_size=self.config.group_size,
+        )
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output
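For readers unfamiliar with the mixed_precision_gemm arguments above: weight_size_bits and weight_bias come from the scalar type (for example uint4b8 is 4 bits with a bias of 8), and group_size controls how many input channels share one scale. Below is a rough PyTorch reference of the dequantization the kernel fuses into the matmul; it assumes already-unpacked integer codes and no zero points (the real kernel also handles w_zp and the packed int32 layout), and reference_wna16_gemm is a hypothetical name, not part of Conch.

import torch


def reference_wna16_gemm(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor,
                         weight_bias: int, group_size: int) -> torch.Tensor:
    # x:   [M, K] activations
    # w_q: [K, N] unpacked integer codes (e.g. values in 0..15 for 4-bit weights)
    # w_s: [K // group_size, N] per-group scales (group_size == -1 -> one group)
    K, N = w_q.shape
    groups = 1 if group_size == -1 else K // group_size
    # Recover signed values by subtracting the scalar-type bias, then scale
    # each group of input channels by its own factor.
    w = w_q.to(x.dtype) - weight_bias
    w = w.reshape(groups, -1, N) * w_s.reshape(groups, 1, N).to(x.dtype)
    return x @ w.reshape(K, N)


# Example: 4-bit (uint4b8) weights, group size 128.
x = torch.randn(2, 256, dtype=torch.float16)
w_q = torch.randint(0, 16, (256, 64))
w_s = torch.rand(256 // 128, 64, dtype=torch.float16)
out = reference_wna16_gemm(x, w_q, w_s, weight_bias=8, group_size=128)

The Conch kernel keeps the weights packed and performs this dequantization inside a Triton kernel instead, which is why the dependency can be listed in requirements/rocm.txt rather than being CUDA-only.
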

vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
     pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
+from vllm.platforms import current_platform

 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -27,6 +28,9 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        # Machete uses CUTLASS, so it can only be compatible with Nvidia
+        if not current_platform.is_cuda():
+            return False, "Machete only supported on CUDA"

         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:

vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py

Lines changed: 4 additions & 0 deletions
@@ -13,6 +13,7 @@
     marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
+from vllm.platforms import current_platform

 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -26,6 +27,9 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        # Marlin uses inline PTX, so it can only be compatible with Nvidia
+        if not current_platform.is_cuda():
+            return False, "Marlin only supported on CUDA"

         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
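Both the Machete and Marlin guards above use the same early-return pattern inside can_implement; a minimal sketch of that pattern follows (the class name CudaOnlyKernelSketch is hypothetical). Triton-based kernels such as Conch omit this check, which is what leaves them eligible on ROCm.

from typing import Optional

from vllm.platforms import current_platform


class CudaOnlyKernelSketch:  # hypothetical stand-in for Machete/Marlin
    @classmethod
    def can_implement(cls, c) -> tuple[bool, Optional[str]]:
        # Reject non-NVIDIA platforms up front so choose_mp_linear_kernel
        # records a clear failure reason instead of a low-level kernel error.
        if not current_platform.is_cuda():
            return False, f"{cls.__name__} only supported on CUDA"
        # ...remaining config checks would go here...
        return True, None
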

0 commit comments
