
Commit be21ef5

[XPU] Supports BF16 for ERNIE-4.5-21B-A3B and ERNIE-4.5-0.3B (#2765)
* fix the no-quant (BF16) XPU MoE path
* change directory of the XPU MoE weight-only method
1 parent 771e71a commit be21ef5

5 files changed (+234, -117 lines)


fastdeploy/model_executor/layers/backends/xpu/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -16,6 +16,6 @@
 xpu backend methods
 """

-from .quantization.weight_only import XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod
+from .quantization.weight_only import XPUWeightOnlyLinearMethod

-__all__ = ['XPUWeightOnlyLinearMethod', 'XPUWeightOnlyMoEMethod']
+__all__ = ['XPUWeightOnlyLinearMethod']

fastdeploy/model_executor/layers/backends/xpu/quantization/weight_only.py

Lines changed: 0 additions & 105 deletions
@@ -13,14 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-
-from typing import Dict
-
 import paddle
 from paddle import nn

-from fastdeploy.model_executor.layers.quantization.quant_base import \
-    QuantMethodBase
 from fastdeploy.model_executor.layers.quantization.weight_only import (
     WeightOnlyConfig, WeightOnlyLinearMethod)
 from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
@@ -63,103 +58,3 @@ def process_loaded_weights(self, layer: nn.Layer,
         layer.linear_weight.set_value(
             paddle.transpose(quanted_weight_tensor, [1, 0]))
         layer.linear_weight_scale.set_value(weight_scale_tensor)
-
-
-class XPUWeightOnlyMoEMethod(QuantMethodBase):
-    """
-    XPU Fused MoE Method.
-    """
-
-    def __init__(
-        self,
-        quant_config: WeightOnlyConfig,
-    ) -> None:
-        super().__init__()
-        self.quant_config = quant_config
-        self.moe_quant_type = self.quant_config.algo
-
-    def create_weights(self, layer: nn.Layer, state_dict: Dict[str,
-                                                                paddle.Tensor]):
-        """
-        Paddle cutlass create weight process.
-        """
-        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
-        assert len(ffn1_weights) == layer.num_local_experts
-        assert len(ffn2_weights) == layer.num_local_experts
-        assert ffn1_weights[0].shape == [
-            layer.hidden_size, layer.moe_intermediate_size * 2
-        ]
-        assert ffn2_weights[0].shape == [
-            layer.moe_intermediate_size, layer.hidden_size
-        ]
-
-        added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
-        added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]
-
-        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
-            weight_name = added_weight_attrs[idx]
-            scale_name = added_scale_attrs[idx]
-
-            weight_list = []
-            weight_scale_list = []
-            for i in range(layer.num_local_experts):
-                quant_weight, scale = weight_quantize_xpu(
-                    weight_tensor[i], self.moe_quant_type, -1,
-                    -1)  # weight is [k,n]
-                weight_list.append(quant_weight.transpose(
-                    [1, 0]))  # transpose weight to [n,k]
-                weight_scale_list.append(scale)
-            quanted_weight = paddle.stack(weight_list, axis=0)
-            setattr(
-                layer, weight_name,
-                layer.create_parameter(
-                    shape=quanted_weight.shape,
-                    dtype=quanted_weight.dtype,
-                    default_initializer=paddle.nn.initializer.Constant(0),
-                ))
-            getattr(layer, weight_name).set_value(quanted_weight)
-
-            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
-            setattr(
-                layer, scale_name,
-                layer.create_parameter(
-                    shape=quanted_weight_scale.shape,
-                    dtype=quanted_weight_scale.dtype,
-                ))
-            getattr(layer, scale_name).set_value(quanted_weight_scale)
-
-    def apply(
-        self,
-        layer: nn.Layer,
-        x: paddle.Tensor,
-        gate_out: paddle.Tensor,
-    ) -> paddle.Tensor:
-        """
-        XPU compute Fused MoE.
-        """
-        from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
-
-        fused_moe_out = xpu_moe_layer(
-            x,
-            layer.gate_weight.transpose([1, 0]),
-            layer.gate_correction_bias,
-            layer.moe_ffn1_weight,
-            layer.moe_ffn2_weight,
-            None,  # ffn1 bias
-            None,  # ffn2 bias
-            (layer.moe_ffn1_weight_scale
-             if hasattr(layer, "moe_ffn1_weight_scale") else None),
-            (layer.moe_ffn2_weight_scale
-             if hasattr(layer, "moe_ffn2_weight_scale") else None),
-            (layer.moe_ffn2_in_scale
-             if hasattr(layer, "moe_ffn2_in_scale") else None),
-            self.moe_quant_type,
-            layer.top_k,
-            False,  # moe group, used in deepseek
-        )
-        if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import \
-                tensor_model_parallel_all_reduce
-            tensor_model_parallel_all_reduce(fused_moe_out)
-
-        return fused_moe_out
fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py (new file)

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
+"""
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from typing import Dict
+
+import paddle
+from paddle import nn
+
+from fastdeploy.model_executor.layers.quantization.quant_base import \
+    QuantMethodBase
+from fastdeploy.model_executor.layers.quantization.weight_only import \
+    WeightOnlyConfig
+from fastdeploy.model_executor.ops.xpu import weight_quantize_xpu
+
+from .fused_moe_backend_base import MoEMethodBase
+
+
+class XPUMoEMethod(MoEMethodBase):
+    """
+    XPU MOE
+    """
+
+    def create_weights(self, layer: nn.Layer, state_dict):
+        """
+        Paddle cutlass create weight process.
+        """
+        # bf16
+        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
+        for weights in [ffn1_weights, ffn2_weights]:
+            for idx, weight in enumerate(weights):
+                weights[idx] = weight.transpose([1, 0])
+        stacked_ffn1_weights = paddle.stack(ffn1_weights, axis=0)
+        stacked_ffn2_weights = paddle.stack(ffn2_weights, axis=0)
+        for idx, weight_tensor in enumerate(
+                [stacked_ffn1_weights, stacked_ffn2_weights]):
+            weight_name = self.added_weight_attrs[idx]
+            setattr(
+                layer, weight_name,
+                layer.create_parameter(
+                    shape=weight_tensor.shape,
+                    dtype=weight_tensor.dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ))
+            getattr(layer, weight_name).set_value(weight_tensor)
+
+    def apply_tp(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Paddle Cutlass compute Fused MoE.
+        """
+        from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
+
+        fused_moe_out = xpu_moe_layer(
+            x,
+            layer.gate_weight.transpose([1, 0]),
+            layer.gate_correction_bias,
+            layer.moe_ffn1_weight,
+            layer.moe_ffn2_weight,
+            None,  # ffn1 bias
+            None,  # ffn2 bias
+            None,  # ffn1 scale
+            None,  # ffn2 scale
+            None,  # ffn1_in_scale
+            "",  # moe_quant_type
+            layer.top_k,
+            False,  # moe group, used in deepseek
+        )
+        if layer.tp_size > 1:
+            from fastdeploy.distributed.communication_op import \
+                tensor_model_parallel_all_reduce
+            tensor_model_parallel_all_reduce(fused_moe_out)
+
+        return fused_moe_out
+
+    def apply_ep_prefill(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Apply the EP prefill method.
+        """
+        raise NotImplementedError
+
+    def apply_ep_decode(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Apply the EP decoder method.
+        """
+        raise NotImplementedError
+
+class XPUWeightOnlyMoEMethod(QuantMethodBase):
+    """
+    XPU Fused MoE Method.
+    """
+
+    def __init__(
+        self,
+        quant_config: WeightOnlyConfig,
+    ) -> None:
+        super().__init__()
+        self.quant_config = quant_config
+        self.moe_quant_type = self.quant_config.algo
+
+    def create_weights(self, layer: nn.Layer, state_dict: Dict[str,
+                                                                paddle.Tensor]):
+        """
+        Paddle cutlass create weight process.
+        """
+        ffn1_weights, ffn2_weights = layer.extract_moe_ffn_weights(state_dict)
+        assert len(ffn1_weights) == layer.num_local_experts
+        assert len(ffn2_weights) == layer.num_local_experts
+        assert ffn1_weights[0].shape == [
+            layer.hidden_size, layer.moe_intermediate_size * 2
+        ]
+        assert ffn2_weights[0].shape == [
+            layer.moe_intermediate_size, layer.hidden_size
+        ]
+
+        added_weight_attrs = ["moe_ffn1_weight", "moe_ffn2_weight"]
+        added_scale_attrs = ["moe_ffn1_weight_scale", "moe_ffn2_weight_scale"]
+
+        for idx, weight_tensor in enumerate([ffn1_weights, ffn2_weights]):
+            weight_name = added_weight_attrs[idx]
+            scale_name = added_scale_attrs[idx]
+
+            weight_list = []
+            weight_scale_list = []
+            for i in range(layer.num_local_experts):
+                quant_weight, scale = weight_quantize_xpu(
+                    weight_tensor[i], self.moe_quant_type, -1,
+                    -1)  # weight is [k,n]
+                weight_list.append(quant_weight.transpose(
+                    [1, 0]))  # transpose weight to [n,k]
+                weight_scale_list.append(scale)
+            quanted_weight = paddle.stack(weight_list, axis=0)
+            setattr(
+                layer, weight_name,
+                layer.create_parameter(
+                    shape=quanted_weight.shape,
+                    dtype=quanted_weight.dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ))
+            getattr(layer, weight_name).set_value(quanted_weight)
+
+            quanted_weight_scale = paddle.stack(weight_scale_list, axis=0)
+            setattr(
+                layer, scale_name,
+                layer.create_parameter(
+                    shape=quanted_weight_scale.shape,
+                    dtype=quanted_weight_scale.dtype,
+                ))
+            getattr(layer, scale_name).set_value(quanted_weight_scale)
+
+    def apply(
+        self,
+        layer: nn.Layer,
+        x: paddle.Tensor,
+        gate_out: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        XPU compute Fused MoE.
+        """
+        from fastdeploy.model_executor.ops.xpu import xpu_moe_layer
+
+        fused_moe_out = xpu_moe_layer(
+            x,
+            layer.gate_weight.transpose([1, 0]),
+            layer.gate_correction_bias,
+            layer.moe_ffn1_weight,
+            layer.moe_ffn2_weight,
+            None,  # ffn1 bias
+            None,  # ffn2 bias
+            (layer.moe_ffn1_weight_scale
+             if hasattr(layer, "moe_ffn1_weight_scale") else None),
+            (layer.moe_ffn2_weight_scale
+             if hasattr(layer, "moe_ffn2_weight_scale") else None),
+            (layer.moe_ffn2_in_scale
+             if hasattr(layer, "moe_ffn2_in_scale") else None),
+            self.moe_quant_type,
+            layer.top_k,
+            False,  # moe group, used in deepseek
+        )
+        if layer.tp_size > 1:
+            from fastdeploy.distributed.communication_op import \
+                tensor_model_parallel_all_reduce
+            tensor_model_parallel_all_reduce(fused_moe_out)
+
+        return fused_moe_out
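For a concrete picture of the BF16 (no-quant) layout that XPUMoEMethod.create_weights builds, here is a minimal standalone sketch of the transpose-and-stack step. The sizes are made up for illustration; in the real path the tensors are bfloat16 expert weights returned by layer.extract_moe_ffn_weights.

import paddle

# Hypothetical sizes, for illustration only; real values come from the model config.
num_local_experts = 4
hidden_size = 8
moe_intermediate_size = 16

# Each expert's ffn1 weight arrives as [hidden_size, moe_intermediate_size * 2]
# (bfloat16 in the real BF16 path; default float32 here so the sketch runs anywhere).
ffn1_weights = [
    paddle.randn([hidden_size, moe_intermediate_size * 2])
    for _ in range(num_local_experts)
]

# Mirror XPUMoEMethod.create_weights: transpose each expert to [n, k] and stack
# all experts into one [num_local_experts, n, k] tensor for the fused kernel.
stacked = paddle.stack([w.transpose([1, 0]) for w in ffn1_weights], axis=0)
print(stacked.shape)  # [4, 32, 8]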

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 17 additions & 8 deletions
@@ -20,9 +20,24 @@

 from fastdeploy import envs
 from fastdeploy.model_executor.layers.utils import get_tensor
-from fastdeploy.platforms import current_platform


+def get_moe_method():
+    """
+    return moe method based on device platform
+    """
+    from fastdeploy.platforms import current_platform
+    if current_platform.is_cuda():
+        from .fused_moe_cutlass_backend import CutlassMoEMethod
+        return CutlassMoEMethod(None)
+    elif current_platform.is_xpu():
+        from .fused_moe_xpu_backend import XPUMoEMethod
+        return XPUMoEMethod(None)
+    elif current_platform.is_gcu():
+        from fastdeploy.model_executor.layers.backends import GCUFusedMoeMethod
+        return GCUFusedMoeMethod(None)
+    raise NotImplementedError()
+
 class FusedMoE(nn.Layer):
     """
     FusedMoE is a layer that performs MoE (Mixture of Experts) computation.
@@ -96,13 +111,7 @@ def __init__(
             self.moe_quant_type = moe_quant_config.name()
         else:
             # now, no quant method(w_fp16 a_fp16) can't get from quant_config, we will optimize it in future
-            if current_platform.is_cuda():
-                from .fused_moe_cutlass_backend import CutlassMoEMethod
-                self.quant_method = CutlassMoEMethod(None)
-            elif current_platform.is_gcu():
-                from fastdeploy.model_executor.layers.backends import \
-                    GCUFusedMoeMethod
-                self.quant_method = GCUFusedMoeMethod(None)
+            self.quant_method = get_moe_method()

         if self.ep_size > 1:
             self.quant_method.init_ep(self)
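As a quick sanity check of the new dispatch, the helper can be called directly. This sketch assumes a FastDeploy build whose current_platform reports XPU (or CUDA/GCU) and that get_moe_method is importable from the module shown above.

from fastdeploy.model_executor.layers.moe.moe import get_moe_method

# Returns the platform-specific no-quant MoE method: XPUMoEMethod on XPU,
# CutlassMoEMethod on CUDA, GCUFusedMoeMethod on GCU; raises NotImplementedError otherwise.
moe_method = get_moe_method()
print(type(moe_method).__name__)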

fastdeploy/model_executor/layers/quantization/weight_only.py

Lines changed: 4 additions & 2 deletions
@@ -60,8 +60,10 @@ def from_config(cls, config: dict) -> "WeightOnlyConfig":

     def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
         if current_platform.is_xpu():
-            from fastdeploy.model_executor.layers.backends import (
-                XPUWeightOnlyLinearMethod, XPUWeightOnlyMoEMethod)
+            from fastdeploy.model_executor.layers.backends import \
+                XPUWeightOnlyLinearMethod
+            from fastdeploy.model_executor.layers.moe.fused_moe_xpu_backend import \
+                XPUWeightOnlyMoEMethod
             if isinstance(layer, FusedMoE):
                 return XPUWeightOnlyMoEMethod(self)
             else:
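To summarize the effect of the import change above, here is a hedged sketch of the per-layer choice that get_quant_method now makes on XPU. The wrapper function below is hypothetical and only mirrors the branch shown in the diff; the XPUWeightOnlyLinearMethod constructor is assumed to take the config the same way the MoE method does.

from fastdeploy.model_executor.layers.backends import XPUWeightOnlyLinearMethod
from fastdeploy.model_executor.layers.moe.fused_moe_xpu_backend import \
    XPUWeightOnlyMoEMethod


def pick_xpu_weight_only_method(quant_config, layer_is_fused_moe: bool):
    """Hypothetical helper mirroring WeightOnlyConfig.get_quant_method on XPU."""
    if layer_is_fused_moe:
        # FusedMoE layers get the MoE method, now housed in the moe backend module.
        return XPUWeightOnlyMoEMethod(quant_config)
    # All other layers keep the weight-only linear method from the XPU backend.
    return XPUWeightOnlyLinearMethod(quant_config)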
