Commit 888780f (parent e3768c5)

[Feature] block_wise_fp8 support triton_moe_backend (#2767)

5 files changed: +248 -10 lines

docs/usage/environment_variables.md

Lines changed: 6 additions & 2 deletions

````diff
@@ -67,6 +67,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Switch from standalone PD to centralized inference (0 or 1)
     "FD_PD_CHANGEABLE":
     lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
-
+
+    # Whether to use DeepGemm for FP8 blockwise MoE.
+    "FD_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
+
 }
-```
+```
````

docs/zh/usage/environment_variables.md

Lines changed: 7 additions & 2 deletions

````diff
@@ -1,5 +1,6 @@
 # FastDeploy environment variable reference
 FastDeploy's environment variables are defined in the fastdeploy/envs.py file at the repository root; the corresponding Chinese-language descriptions are given below:
+
 ```python
 environment_variables: dict[str, Callable[[], Any]] = {
     # CUDA architecture versions used when building FastDeploy; a list of strings, e.g. [80,90]
@@ -65,6 +66,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to switch from standalone PD disaggregation to centralized inference
     "FD_PD_CHANGEABLE":
     lambda: os.getenv("FD_PD_CHANGEABLE", "1"),
-
+
+    # Whether to use the DeepGemm backend for FP8 blockwise MoE.
+    "FD_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
+
 }
-```
+```
````

fastdeploy/envs.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -97,6 +97,10 @@
     # Whether to use fastsafetensor load weight (0 or 1)
     "FD_USE_FASTSAFETENSOR":
     lambda: os.getenv("FD_USE_FASTSAFETENSOR", "0"),
+
+    # Whether to use DeepGemm for FP8 blockwise MoE.
+    "FD_USE_DEEP_GEMM":
+    lambda: bool(int(os.getenv("FD_USE_DEEP_GEMM", "1"))),
 }
```
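The flag is parsed as an integer-backed boolean, so DeepGemm stays on by default. A minimal standalone sketch of the toggle semantics, mirroring the entry above:

```python
import os

# FD_USE_DEEP_GEMM defaults to "1" (DeepGemm backend). Exporting it as "0"
# selects the Triton blockwise-FP8 MoE backend added in this commit.
os.environ["FD_USE_DEEP_GEMM"] = "0"

use_deep_gemm = bool(int(os.getenv("FD_USE_DEEP_GEMM", "1")))
print(use_deep_gemm)  # False
```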

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 220 additions & 2 deletions

```diff
@@ -20,7 +20,8 @@
 import fastdeploy
 from fastdeploy.distributed.communication_op import \
     tensor_model_parallel_all_reduce
-from fastdeploy.model_executor.layers.utils import get_tensor
+from fastdeploy.model_executor.layers.utils import (create_and_set_parameter,
+                                                    get_tensor)
 from fastdeploy.utils import ceil_div
 
 from ..quantization.quant_base import QuantMethodBase
@@ -191,7 +192,7 @@ def apply(
 
         ffn2_input = paddle.incubate.nn.functional.swiglu(
             ffn1_out)
-
+
         ffn2_out = paddle.empty(
             (token_num * top_k, hidden_size),
             dtype=x.dtype,
@@ -484,3 +485,220 @@ def apply(
             tensor_model_parallel_all_reduce(out)
 
         return out
+
+
+class BlockWiseFP8MoEMethod(QuantMethodBase):
+    """
+    Use Triton Group Gemm to compute Fused BlockWise FP8 Quant MoE.
+    """
+
+    def __init__(self, quant_config):
+        """
+        Triton Group Gemm to compute Fused MoE.
+        """
+        self.quant_config = quant_config
+        self.added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
+        self.added_scale_attrs = [
+            "moe_ffn1_weight_scale", "moe_ffn2_weight_scale"
+        ]
+
+    def process_prequanted_weights(self, layer: nn.Layer, state_dict) -> None:
+        """process_prequanted_weights"""
+        raise NotImplementedError()
+
+    def create_weights(self, layer: nn.Layer, state_dict):
+        """
+        Triton MoE create weight process.
+        """
+        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
+
+        self.check(layer, ffn1_weights, ffn2_weights)
+
+        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
+            weight_name = self.added_weight_attrs[idx]
+            scale_name = self.added_scale_attrs[idx]
+
+            weight_list = []
+            weight_scale_list = []
+            for i in range(layer.num_local_experts):
+                from fastdeploy.model_executor.layers.utils import \
+                    per_block_cast_to_fp8
+                quant_weight, scale = per_block_cast_to_fp8(
+                    weight_tensor[i], self.quant_config.weight_block_size)
+
+                weight_list.append(quant_weight)
+                weight_scale_list.append(scale)
+            quanted_weight = paddle.stack(weight_list, axis=0)
+            quanted_weight = quanted_weight.transpose([0, 2, 1]).contiguous()
+            create_and_set_parameter(layer, weight_name, quanted_weight)
+
+            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
+            quanted_weight_scale = quanted_weight_scale.transpose(
+                [0, 2, 1]).contiguous()
+            create_and_set_parameter(layer, scale_name, quanted_weight_scale)
+
+    def check(self, layer: nn.Layer, ffn1_weights, ffn2_weights):
+        """
+        check layer is valid for this method
+        """
+        assert ffn1_weights[0].shape == [
+            layer.hidden_size, layer.moe_intermediate_size * 2
+        ]
+        assert ffn2_weights[0].shape == [
+            layer.moe_intermediate_size, layer.hidden_size
+        ]
```
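`create_weights` quantizes each expert's FFN weight with `per_block_cast_to_fp8` from `fastdeploy.model_executor.layers.utils`, then registers the transposed FP8 tensor and one scale per block on the layer. As a mental model only, here is a minimal sketch of such a per-block cast, assuming `weight_block_size` tiles of e.g. [128, 128], the E4M3 bound of 448 that `BlockWiseFP8Config` sets, and a Paddle build with `float8_e4m3fn` support; the real helper may differ in padding and layout:

```python
import paddle

def per_block_cast_to_fp8_sketch(w: paddle.Tensor, block=(128, 128)):
    """Illustrative only: quantize a 2-D weight with one FP8 scale per tile."""
    bk, bn = block
    k, n = w.shape
    # Zero-pad each dim up to a multiple of the tile size (paddle's pad list
    # runs [dim0_before, dim0_after, dim1_before, dim1_after]).
    w = paddle.nn.functional.pad(w, [0, (-k) % bk, 0, (-n) % bn])
    tk, tn = w.shape[0] // bk, w.shape[1] // bn
    tiles = w.reshape([tk, bk, tn, bn])
    # One scale per tile: map the tile's max |value| onto E4M3's bound of 448.
    scale = tiles.abs().max(axis=[1, 3], keepdim=True).clip(min=1e-6) / 448.0
    w_fp8 = (tiles / scale).cast(paddle.float8_e4m3fn)
    # Return the (padded) FP8 weight and the [tk, tn] grid of scales.
    return w_fp8.reshape(w.shape), scale.reshape([tk, tn])
```

The stacked per-expert scales are what the Triton kernel later reads through `stride_bse`/`stride_bsk`/`stride_bsn`. The hunk continues with the new method's `apply`: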
551+
```diff
+    def apply(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Triton compute Fused MoE.
+        """
+        token_num = x.shape[0]
+        top_k = layer.top_k
+        num_local_experts = layer.num_local_experts
+        moe_intermediate_size = layer.moe_intermediate_size
+        hidden_size = layer.hidden_size
+        E, N1, _ = layer.moe_ffn1_weight.shape
+        N2 = layer.moe_ffn2_weight.shape[1]
+
+        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+            gate_out,
+            layer.gate_correction_bias,
+            layer.top_k,
+            True,  # apply_norm_weight
+            False,
+        )
+
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": self.quant_config.weight_block_size[1],
+            "BLOCK_SIZE_K": self.quant_config.weight_block_size[0],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
+        from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess(
+            topk_ids, num_local_experts, config["BLOCK_SIZE_M"])
+        max_num_tokens_padded = sorted_token_ids.shape[0]
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(moe_intermediate_size * 2, config["BLOCK_SIZE_N"]), )
+
+        from .triton_moe_kernels import fused_moe_kernel_paddle
+
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            x, self.quant_config.weight_block_size[0])
+
+        cache13 = paddle.empty([token_num * top_k * max(N1, N2)],
+                               dtype=x.dtype)
+        intermediate_cache1 = cache13[:token_num * top_k * N1].view(
+            [token_num * top_k, N1])
+        intermediate_cache3 = cache13[:token_num * top_k * N2].view(
+            [token_num * top_k, N2])
+
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.moe_ffn1_weight.view(paddle.float8_e4m3fn),
+            intermediate_cache1,
+            x_scale,
+            layer.moe_ffn1_weight_scale,
+            None,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_num_tokens_padded,
+            token_num * top_k,
+            N=moe_intermediate_size * 2,
+            K=hidden_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_q.strides[1],
+            stride_be=layer.moe_ffn1_weight.strides[0],
+            stride_bk=layer.moe_ffn1_weight.strides[2],
+            stride_bn=layer.moe_ffn1_weight.strides[1],
+            stride_cm=intermediate_cache1.strides[0],
+            stride_cn=intermediate_cache1.strides[1],
+            #
+            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
+            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
+            stride_bse=layer.moe_ffn1_weight_scale.strides[0],
+            stride_bsk=layer.moe_ffn1_weight_scale.strides[2],
+            stride_bsn=layer.moe_ffn1_weight_scale.strides[1],
+            group_n=self.quant_config.weight_block_size[1],
+            group_k=self.quant_config.weight_block_size[0],
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=False,
+            top_k=top_k,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
+        intermediate_cache2 = paddle.incubate.nn.functional.swiglu(
+            intermediate_cache1)
+
+        grid = (ceil_div(max_num_tokens_padded, config["BLOCK_SIZE_M"]) *
+                ceil_div(hidden_size, config["BLOCK_SIZE_N"]), )
+
+        x_q, x_scale = fastdeploy.model_executor.ops.gpu.per_token_quant(
+            intermediate_cache2, self.quant_config.weight_block_size[0])
+
+        fused_moe_kernel_paddle[grid](
+            x_q,
+            layer.moe_ffn2_weight.view(paddle.float8_e4m3fn),
+            intermediate_cache3,
+            x_scale,
+            layer.moe_ffn2_weight_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            max_num_tokens_padded,
+            token_num * top_k,
+            N=hidden_size,
+            K=moe_intermediate_size,
+            stride_am=x_q.strides[0],
+            stride_ak=x_q.strides[1],
+            stride_be=layer.moe_ffn2_weight.strides[0],
+            stride_bk=layer.moe_ffn2_weight.strides[2],
+            stride_bn=layer.moe_ffn2_weight.strides[1],
+            stride_cm=intermediate_cache3.strides[0],
+            stride_cn=intermediate_cache3.strides[1],
+            stride_asm=x_scale.strides[0],  # only used in blockwise fp8
+            stride_ask=x_scale.strides[1],  # only used in blockwise fp8
+            stride_bse=layer.moe_ffn2_weight_scale.strides[0],
+            stride_bsk=layer.moe_ffn2_weight_scale.strides[2],
+            stride_bsn=layer.moe_ffn2_weight_scale.strides[1],
+            group_n=self.quant_config.weight_block_size[1],
+            group_k=self.quant_config.weight_block_size[0],
+            # Meta-parameters
+            BLOCK_SIZE_M=config["BLOCK_SIZE_M"],
+            BLOCK_SIZE_N=config["BLOCK_SIZE_N"],
+            BLOCK_SIZE_K=config["BLOCK_SIZE_K"],
+            GROUP_SIZE_M=config["GROUP_SIZE_M"],
+            MUL_ROUTED_WEIGHT=True,
+            top_k=1,
+            compute_type_enum=1,
+            use_fp8_w8a8=True,
+            use_int8_w8a16=False,
+            even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0,
+        )
+
+        intermediate_cache3.reshape_([token_num, top_k, hidden_size])
+        out = intermediate_cache3.sum(axis=1)
+
+        if layer.tp_size > 1:
+            tensor_model_parallel_all_reduce(out)
+
+        return out
```
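Two launches of `fused_moe_kernel_paddle` implement the expert FFN: the first computes the gate/up projection over the expert-sorted, block-padded rows (`MUL_ROUTED_WEIGHT=False`), and the second computes the down projection with the routing weights folded in (`MUL_ROUTED_WEIGHT=True`, `top_k=1`, since each row is already a single (token, expert) pair). Note that `intermediate_cache1` and `intermediate_cache3` alias the same `cache13` buffer; this is safe because cache1 is fully consumed by the swiglu before cache3 is written. Ignoring quantization and padding, the computation is equivalent to this dense sketch (illustrative only, not the library implementation; `w1`/`w2` follow the transposed layouts produced by `create_weights`):

```python
import paddle

def fused_moe_reference(x, w1, w2, topk_ids, topk_weights):
    """x: [T, H]; w1: [E, 2*I, H]; w2: [E, H, I]; returns [T, H]."""
    token_num, top_k = topk_ids.shape
    outs = []
    for t in range(token_num):
        acc = paddle.zeros([x.shape[1]], dtype=x.dtype)
        for j in range(top_k):
            e = int(topk_ids[t, j])
            h = paddle.matmul(x[t:t + 1], w1[e], transpose_y=True)  # GEMM1: [1, 2*I]
            h = paddle.incubate.nn.functional.swiglu(h)             # SwiGLU: [1, I]
            y = paddle.matmul(h, w2[e], transpose_y=True)           # GEMM2: [1, H]
            acc += (y * topk_weights[t, j])[0]                      # apply routing weight
        outs.append(acc)                                            # sum over top-k experts
    return paddle.stack(outs)
```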

fastdeploy/model_executor/layers/quantization/block_wise_fp8.py

Lines changed: 11 additions & 4 deletions

```diff
@@ -18,9 +18,10 @@
 import paddle
 
 import fastdeploy
+from fastdeploy import envs
 from fastdeploy.model_executor.layers.moe import FusedMoE
 
-from ..utils import per_block_cast_to_fp8, get_tensor
+from ..utils import get_tensor, per_block_cast_to_fp8
 from .quant_base import QuantConfigBase, QuantMethodBase
 
 
@@ -37,6 +38,7 @@ def __init__(self, weight_block_size: list = [-1, -1]) -> None:
         self.quant_max_bound = 448
         self.quant_min_bound = -448
         self.quant_round_type = 1
+        self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
 
     def name(self) -> str:
         return "block_wise_fp8"
@@ -51,9 +53,14 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         Get quantization method.
         '''
         if isinstance(layer, FusedMoE):
-            from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
-                DeepGemmFusedMoeMethod
-            return DeepGemmFusedMoeMethod(self)
+            if self.use_deep_gemm:
+                from fastdeploy.model_executor.layers.moe.fused_moe_deepgemm_backend import \
+                    DeepGemmFusedMoeMethod
+                return DeepGemmFusedMoeMethod(self)
+            else:
+                from fastdeploy.model_executor.layers.moe.fused_moe_triton_backend import \
+                    BlockWiseFP8MoEMethod
+                return BlockWiseFP8MoEMethod(self)
         else:
             return BlockWiseFP8LinearMethod(self)
 
```
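After this change, `BlockWiseFP8Config.get_quant_method` selects the MoE backend from the environment. A hypothetical usage sketch (the `moe_layer` below stands in for a constructed `FusedMoE` instance, and the `[128, 128]` block size is an assumption, not something this commit fixes):

```python
import os

# Set before BlockWiseFP8Config is constructed, which reads envs.FD_USE_DEEP_GEMM.
os.environ["FD_USE_DEEP_GEMM"] = "0"

from fastdeploy.model_executor.layers.quantization.block_wise_fp8 import \
    BlockWiseFP8Config

cfg = BlockWiseFP8Config(weight_block_size=[128, 128])  # block size assumed
method = cfg.get_quant_method(moe_layer)  # moe_layer: a FusedMoE (assumed built elsewhere)
# -> BlockWiseFP8MoEMethod (Triton) here; with FD_USE_DEEP_GEMM unset or "1",
#    DeepGemmFusedMoeMethod is returned instead.
```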
