
Commit b8e809a

[Kernel] Support deep_gemm for linear methods (#19085)
Signed-off-by: artetaout <lulala341@gmail.com>
1 parent 5039ec2 commit b8e809a

3 files changed, +124 -1 lines changed

Lines changed: 84 additions & 0 deletions (new file)
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import logging

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import direct_register_custom_op

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
if has_deep_gemm:
    import deep_gemm

logger = logging.getLogger(__name__)


def prepare_block_fp8_matmul_inputs(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> tuple[int, int, int, torch.Tensor]:
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]
    assert A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]

    M = A.numel() // A.shape[-1]

    assert B.ndim == 2
    assert B.is_contiguous()
    assert Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N, )
    C = A.new_empty(C_shape, dtype=output_dtype)

    return M, N, K, C


def w8a8_block_fp8_matmul_deepgemm(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    # DeepGEMM only supports bfloat16 output tensors
    assert C.dtype == torch.bfloat16
    deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    return C


def w8a8_block_fp8_matmul_deepgemm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype,
) -> torch.Tensor:
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
                                                 output_dtype)
    return C


direct_register_custom_op(
    op_name="w8a8_block_fp8_matmul_deepgemm",
    op_func=w8a8_block_fp8_matmul_deepgemm,
    mutates_args=[],
    fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
    dispatch_key=current_platform.dispatch_key,
)
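
The new module registers w8a8_block_fp8_matmul_deepgemm as a vLLM custom op via direct_register_custom_op. Below is a minimal usage sketch, not part of the commit: it assumes a Hopper (SM90) GPU with deep_gemm installed, that the module above has already been imported so the op is registered under torch.ops.vllm, and that per_token_group_quant_fp8 can be imported from fp8_utils (where this commit also uses it). The shapes, dummy weight, and dummy scales are illustrative.

# Minimal usage sketch (NOT part of the commit). Assumptions: Hopper GPU,
# deep_gemm installed, the new module already imported so the custom op is
# registered, and illustrative sizes with N and K divisible by 128.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

M, N, K = 16, 512, 256
block_size = [128, 128]  # [block_n, block_k]

# Quantize activations per token group with column-major scales, as the
# commit does before calling the DeepGEMM op.
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
q_input, x_scale = per_token_group_quant_fp8(x, block_size[1],
                                             column_major_scales=True)

# Dummy block-quantized weight: fp8 values plus one fp32 scale per 128x128 block.
weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.ones(N // 128, K // 128, device="cuda",
                          dtype=torch.float32)

out = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
    q_input, weight, x_scale, weight_scale, block_size,
    output_dtype=torch.bfloat16)  # DeepGEMM only produces bfloat16 output
print(out.shape)  # torch.Size([16, 512])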

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 0 deletions
@@ -402,6 +402,7 @@ def apply(self,

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
+
            return torch.ops.vllm.apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 39 additions & 1 deletion
@@ -3,12 +3,14 @@

# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
+import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union

import torch

+import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -20,6 +22,7 @@
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)
+has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
    return w8a8_block_fp8_matmul


+def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
+    """
+    Check if DeepGEMM should be used based on the output dtype and weight shape.
+    DeepGEMM is only supported for bfloat16 output dtype and weights with shape
+    divisible by 128.
+    """
+
+    return (current_platform.is_cuda()
+            and current_platform.is_device_capability(90) and has_deep_gemm
+            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
+            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+
+
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
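
A small sketch, not from the commit, of how this gate behaves, assuming should_use_deepgemm can be imported from fp8_utils where the hunk above defines it. A True result additionally requires a CUDA device with compute capability 9.0, deep_gemm installed, and VLLM_USE_DEEP_GEMM enabled; on any other setup every call returns False.

# Gate behaviour sketch (NOT part of the commit); import path assumed.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    should_use_deepgemm)

w_ok = torch.empty(512, 256)   # both dims divisible by 128
w_bad = torch.empty(512, 200)  # second dim not divisible by 128

# True only on a CUDA SM90 device with deep_gemm installed and
# VLLM_USE_DEEP_GEMM set; otherwise all three calls return False.
print(should_use_deepgemm(torch.bfloat16, w_ok))
print(should_use_deepgemm(torch.float16, w_ok))    # dtype gate: bf16 required
print(should_use_deepgemm(torch.bfloat16, w_bad))  # shape gate: 128-divisibility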
@@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
+    output_dtype = input.dtype
+
+    if should_use_deepgemm(output_dtype, weight):
+
+        input_2d = input.view(-1, input.shape[-1])
+        output_shape = [*input.shape[:-1], weight.shape[0]]
+
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d,
+            block_size[1],
+            column_major_scales=True,
+        )
+
+        output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
+            q_input,
+            weight,
+            x_scale,
+            weight_scale,
+            block_size,
+            output_dtype=output_dtype)
+        if bias is not None:
+            output += bias
+        return output.to(dtype=output_dtype).view(*output_shape)

    if current_platform.is_cuda():
        if current_platform.has_device_capability(100):
@@ -134,7 +173,6 @@ def ceil_div(x: int, y: int) -> int:

    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
        use_cutlass, use_aiter_and_is_supported)
-
    if use_cutlass:
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=use_cutlass)
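
Taken together, these hunks make the DeepGEMM path opt-in through the VLLM_USE_DEEP_GEMM environment variable, with a fallback to the pre-existing CUTLASS/Triton block-scale kernels whenever should_use_deepgemm returns False. A rough end-to-end sketch follows; the model path is a placeholder, and the hardware and software assumptions (Hopper GPU, deep_gemm installed, FP8 block-quantized checkpoint) match those above.

# End-to-end sketch (NOT part of the commit). Placeholder model path; requires a
# Hopper GPU, deep_gemm installed, and an FP8 block-quantized checkpoint.
import os

# vllm.envs reads VLLM_USE_DEEP_GEMM from the environment, so set it before use.
os.environ["VLLM_USE_DEEP_GEMM"] = "1"

from vllm import LLM

llm = LLM(model="path/to/fp8-block-quantized-model")  # placeholder
print(llm.generate("Hello")[0].outputs[0].text)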
