Commit d0f4d6b

[GCU] Support gcu platform (#2702)
baseline: e7fa57e

Co-authored-by: yongqiangma <xing.wo@163.com>
1 parent 26d5d73 commit d0f4d6b

33 files changed: 2,988 additions and 85 deletions

build.sh

Lines changed: 8 additions & 0 deletions
@@ -113,6 +113,14 @@ function copy_ops(){
         return
     fi

+    is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
+    if [ "$is_gcu" = "True" ]; then
+        DEVICE_TYPE="gcu"
+        cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
+        echo -e "gcu ops have been copy to fastdeploy"
+        return
+    fi
+
     DEVICE_TYPE="cpu"
     cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
     cd ../../../../
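
The probe added to `copy_ops` shells out to Python and compares the printed value with the string `True`. A minimal Python sketch of the same check follows; the helper names `parse_probe_output` and `is_gcu_build` are hypothetical and not part of the commit, and only `paddle.is_compiled_with_custom_device('gcu')` comes from the diff.

```python
import subprocess
import sys


def parse_probe_output(stdout: str) -> bool:
    """Return True only when the probe printed the literal string 'True'."""
    return stdout.strip() == "True"


def is_gcu_build(python: str = sys.executable) -> bool:
    """Ask the given interpreter whether its Paddle build has the 'gcu' custom device.

    Mirrors the one-liner in build.sh; the function name is illustrative only.
    """
    probe = "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"
    result = subprocess.run([python, "-c", probe], capture_output=True, text=True)
    # A failed import (for example, Paddle is not installed) counts as "not a GCU build".
    return result.returncode == 0 and parse_probe_output(result.stdout)
```

Like the shell version, this compares the whole captured output, so any extra text printed to stdout during `import paddle` would make the check evaluate to false.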

custom_ops/setup_ops.py

Lines changed: 11 additions & 0 deletions
@@ -501,6 +501,17 @@ def find_end_files(directory, end_str):
             ],
         ),
     )
+elif paddle.is_compiled_with_custom_device("gcu"):
+    setup(
+        name="fastdeploy_ops",
+        ext_modules=CppExtension(
+            sources=[
+                "gpu_ops/save_with_output_msg.cc",
+                "gpu_ops/get_output.cc",
+                "gpu_ops/get_output_msg_with_topk.cc",
+            ]
+        ),
+    )
 else:
     use_bf16 = envs.FD_CPU_USE_BF16 == "True"

docs/get_started/installation/Enflame_gcu.md

Lines changed: 23 additions & 17 deletions
@@ -1,8 +1,8 @@
-# Running ERNIE-4.5-21B-A3B with FastDeploy
+# Running ERNIE 4.5 Series Models with FastDeploy

 The Enflame S60 ([Learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card designed for large-scale deployment in data centers. It meets the demands of large language models (LLMs), search/advertising/recommendation systems, and traditional models. Characterized by broad model coverage, user-friendliness, and high portability, it is widely applicable to mainstream inference scenarios such as image and text generation applications, search and recommendation systems, and text/image/speech recognition.

-FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications.
+FastDeploy has deeply adapted and optimized the ERNIE 4.5 Series Models for the Enflame S60, achieving a unified inference interface between GCU and GPU. This allows seamless migration of inference tasks without code modifications.

 ## 🚀 Quick Start 🚀

@@ -27,15 +27,15 @@ lspci | grep S60
 3b:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01)
 3c:00.0 Processing accelerators: Shanghai Enflame Technology Co. Ltd S60 [Enflame] (rev 01)
 ```
-### 1. Environment Setup (Estimated time: 510 minutes)
+### 1. Environment Setup (Estimated time: 5-10 minutes)
 1. Pull the Docker image
 ```bash
 # Note: This image only contains the Paddle development environment, not precompiled PaddlePaddle packages
-docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84
+docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
 ```
 2. Start the container
 ```bash
-docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash
+docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash
 ```
 3. Obtain and install drivers<br/>
 **Full software packages are preloaded in the Docker container. Copy them to an external directory, e.g., ```/home/workspace/deps/```**
@@ -67,39 +67,45 @@ python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.c
 7. Install FastDeploy and dependencies
 ```bash
 python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
-apt install python3.10-distutils
+# For source compilation, refer to the following steps
+git clone https://github.com/PaddlePaddle/FastDeploy
+cd FastDeploy
+python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
+bash build.sh 1
 ```
-### 2. Data Preparation (Estimated time: 25 minutes)
+### 2. Data Preparation (Estimated time: 2-5 minutes)
 Use a trained model for inference on GSM8K dataset:
 ```bash
 mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/
 wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
 ```
-Place model weights in a directory, e.g., ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/```
-### 3. Inference (Estimated time: 25 minutes)
+Place model weights in a directory, e.g., ```/work/models/ERNIE-4.5-300B-A47B-Paddle/```
+### 3. Inference (Estimated time: 2-5 minutes)
 Start the inference service:
 ```bash
 python -m fastdeploy.entrypoints.openai.api_server \
---model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \
+--model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \
 --port 8188 \
 --metrics-port 8200 \
---tensor-parallel-size 4 \
---max-model-len 8192 \
---num-gpu-blocks-override 1024
+--tensor-parallel-size 8 \
+--max-model-len 32768 \
+--num-gpu-blocks-override 4096 \
+--max-num-batched-tokens 32768 \
+--quantization "wint4"
 ```
 Query the model service:
 ```bash
 curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
 -H "Content-Type: application/json" \
 -d '{
 "messages": [
-{"role": "user", "content": "The largest ocean is"}
+{"role": "user", "content": "Where is Beijing?"}
 ]
 }'
 ```
 Successful execution returns inference results, e.g.:
 ```json
-{"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}}
+{"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}}
 ```
 ### 4. Accuracy Testing (Estimated time: 60–180 minutes)
 Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify sampling parameters, e.g.:
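
The `curl` query in the hunk above can also be issued from Python. The sketch below uses only the standard library and assumes the service started by the documented command is listening on port 8188; the helper names `build_chat_request`, `extract_reply`, and `ask` are hypothetical, and the response layout is taken from the sample JSON in the diff.

```python
import json
import urllib.request


def build_chat_request(prompt: str, base_url: str = "http://0.0.0.0:8188") -> urllib.request.Request:
    """Build the same POST request as the curl example in the docs."""
    payload = {"messages": [{"role": "user", "content": prompt}]}
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def extract_reply(response_body: str) -> str:
    """Pull the assistant text out of a chat.completion response body."""
    body = json.loads(response_body)
    return body["choices"][0]["message"]["content"]


def ask(prompt: str, base_url: str = "http://0.0.0.0:8188", timeout: float = 60.0) -> str:
    """Send the prompt to a running service and return the reply text."""
    request = build_chat_request(prompt, base_url)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return extract_reply(response.read().decode("utf-8"))
```

With the service running, `ask("Where is Beijing?")` would return the `content` field of the first choice.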
@@ -120,10 +126,10 @@ data = {
 Run accuracy tests:
 ```bash
 cd /home/workspace/benchmark/
-python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2
+python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
 ```
 Upon completion, accuracy results are saved in ```result.jsonl```, e.g.:
 ```json
-{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}}
+{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
 ```
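
A record in `result.jsonl` can be summarized with a few lines of Python. This is a sketch under two assumptions that the diff does not confirm: that `latency` is the total wall-clock time of the run in seconds, and that each line of the file is one JSON record with the fields shown in the sample. The function names are hypothetical.

```python
import json
from typing import Dict, List


def summarize_result(line: str) -> Dict[str, float]:
    """Reduce one result.jsonl record to accuracy and amortized time per question."""
    record = json.loads(line)
    num_requests = record["num_requests"]
    return {
        "accuracy": record["accuracy"],
        "num_requests": num_requests,
        # Assumption: "latency" is total run time in seconds, so this is an
        # amortized figure across parallel requests, not a per-request latency.
        "seconds_per_question": record["latency"] / num_requests,
    }


def read_results(path: str = "result.jsonl") -> List[Dict[str, float]]:
    """Summarize every non-empty line of a results file."""
    with open(path, encoding="utf-8") as handle:
        return [summarize_result(line) for line in handle if line.strip()]
```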

docs/zh/get_started/installation/Enflame_gcu.md

Lines changed: 22 additions & 16 deletions
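@@ -1,8 +1,8 @@
-# Running the ERNIE-4.5-21B-A3B Model on the Enflame S60 with FastDeploy
+# Running ERNIE 4.5 Series Models on the Enflame S60 with FastDeploy

 The Enflame S60 ([Learn about Enflame](https://www.enflame-tech.com/)) is a next-generation AI inference accelerator card for large-scale data center deployment. It serves large language models, search/advertising/recommendation systems, and traditional models, and offers broad model coverage, ease of use, and easy migration and deployment. It applies to mainstream inference scenarios such as image and text generation, search and recommendation, and text, image, and speech recognition.

-FastDeploy has deeply adapted and optimized the ernie-4_5-21b-a3b-bf16-paddle model for the Enflame S60 and unified the GCU inference entry point with the GPU one, so inference tasks can be migrated without modification.
+FastDeploy has deeply adapted and optimized the ERNIE 4.5 series models for the Enflame S60 and unified the GCU inference entry point with the GPU one, so inference tasks can be migrated without modification.

 ## 🚀 Quick Start 🚀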

@@ -30,11 +30,11 @@ lspci | grep S60
 1. Pull the image
 ```bash
 # Note: this image only provides the Paddle development environment; it does not include a precompiled PaddlePaddle package
-docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84
+docker pull ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84
 ```
 2. Start the container with the following command
 ```bash
-docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.4.623-ubuntu20-x86_64-gcc84 /bin/bash
+docker run --name paddle-gcu-llm -v /home:/home -v /work:/work --network=host --ipc=host -it --privileged ccr-2vdh3abv-pub.cnc.bj.baidubce.com/device/paddle-gcu:topsrider3.5.102-ubuntu20-x86_64-gcc84 /bin/bash
 ```
 3. Obtain and install the driver<br/>
 **The full software packages are preloaded inside the Docker container; copy them to a directory outside the container, e.g. ```/home/workspace/deps/```**
@@ -63,42 +63,48 @@ python -m pip install paddlepaddle==3.1.0a0 -i https://www.paddlepaddle.org.cn/p
 python -m pip install paddle-custom-gcu==3.1.0 -i https://www.paddlepaddle.org.cn/packages/stable/gcu/
 # To build and install from source, see https://github.com/PaddlePaddle/PaddleCustomDevice/blob/develop/backends/gcu/README_cn.md
 ```
-7. Install FastDeploy and dependencies<br/>
+7. Install FastDeploy <br/>
 ```bash
 python -m pip install fastdeploy -i https://www.paddlepaddle.org.cn/packages/stable/gcu/ --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
-apt install python3.10-distutils
+# To build and install from source, follow the steps below
+git clone https://github.com/PaddlePaddle/FastDeploy
+cd FastDeploy
+python -m pip install -r requirements.txt --extra-index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simplels
+bash build.sh 1
 ```
 ### 2. Data preparation (takes about 2~5 min)
 Use a trained model to run inference on GSM8K
 ```bash
 mkdir -p /home/workspace/benchmark/ && cd /home/workspace/benchmark/
 wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
 ```
-Prepare the model and weights and place them in a local directory, e.g. ```/work/models/ernie-4_5-21b-a3b-bf16-paddle/```
+Prepare the model and weights and place them in a local directory, e.g. ```/work/models/ERNIE-4.5-300B-A47B-Paddle/```
 ### 3. Inference (takes about 2~5 min)
 Run the following command to start the inference service
 ```bash
 python -m fastdeploy.entrypoints.openai.api_server \
---model "/work/models/ernie-4_5-21b-a3b-bf16-paddle/" \
+--model "/work/models/ERNIE-4.5-300B-A47B-Paddle/" \
 --port 8188 \
 --metrics-port 8200 \
---tensor-parallel-size 4 \
---max-model-len 8192 \
---num-gpu-blocks-override 1024
+--tensor-parallel-size 8 \
+--max-model-len 32768 \
+--num-gpu-blocks-override 4096 \
+--max-num-batched-tokens 32768 \
+--quantization "wint4"
 ```
 Use the following command to query the model service
 ```bash
 curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
 -H "Content-Type: application/json" \
 -d '{
 "messages": [
-{"role": "user", "content": "The largest ocean is"}
+{"role": "user", "content": "Where is Beijing?"}
 ]
 }'
 ```
 After a successful run, the generated inference result is returned, for example:
 ```json
-{"id":"chatcmpl-5cd96f3b-eff3-4dc0-8aa2-8b5d7b7b86f2","object":"chat.completion","created":1751167862,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"3. **Pacific Ocean**: The Pacific Ocean is the largest and deepest of the world's oceans. It covers an area of approximately 181,344,000 square kilometers, which is more than 30% of the Earth's surface. It is located between the Americas to the west and east, and Asia and Australia to the north and south. The Pacific Ocean is known for its vastness, diverse marine life, and numerous islands.\n\nIn summary, the largest ocean in the world is the Pacific Ocean.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":127,"completion_tokens":116,"prompt_tokens_details":{"cached_tokens":0}}}
+{"id":"chatcmpl-20f1210d-6943-4110-ad2d-c76ba11604ad","object":"chat.completion","created":1751621261,"model":"default","choices":[{"index":0,"message":{"role":"assistant","content":"Beijing is the capital city of the People's Republic of China, located in the northern part of the country. It is situated in the North China Plain, bordered by the mountains to the west, north, and northeast. Beijing serves as China's political, cultural, and international exchange center, playing a crucial role in the nation's development and global interactions.","reasoning_content":null,"tool_calls":null},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"total_tokens":88,"completion_tokens":77,"prompt_tokens_details":{"cached_tokens":0}}}
 ```
 ### 4. Accuracy testing (takes about 60~180 min)
 Place the accuracy script ```bench_gsm8k.py``` in ```/home/workspace/benchmark/``` and modify the sampling parameters, e.g.:
@@ -119,10 +125,10 @@ data = {
 Run the following command to start the accuracy test
 ```bash
 cd /home/workspace/benchmark/
-python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 2
+python -u bench_gsm8k.py --port 8188 --num-questions 1319 --num-shots 5 --parallel 8
 ```
-After a successful run, the accuracy results are written to ```result.jsonl``` in the current directory, for example (partial dataset, for illustration only):
+After a successful run, the accuracy results are written to ```result.jsonl``` in the current directory, for example:
 ```json
-{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 365.548, "accuracy": 0.967, "num_requests": 30, "other": {"num_questions": 30, "parallel": 2}}
+{"task": "gsm8k", "backend": "paddlepaddle", "num_gpus": 1, "latency": 13446.01, "accuracy": 0.956, "num_requests": 1319, "other": {"num_questions": 1319, "parallel": 8}}
 ```

fastdeploy/model_executor/layers/activation.py

Lines changed: 18 additions & 1 deletion
@@ -19,7 +19,7 @@

 import paddle
 from paddle import nn
-from paddle.incubate.nn.functional import fused_bias_act
+from paddle.incubate.nn.functional import fused_bias_act, swiglu

 from fastdeploy.config import FDConfig
 from fastdeploy.platforms import current_platform
@@ -66,6 +66,8 @@ def __init__(
         if current_platform.is_cuda() or current_platform.is_xpu(
         ) or current_platform.is_iluvatar():
             self.forward = self.forward_cuda
+        elif current_platform.is_gcu():
+            self.forward = self.forward_gcu
         else:
             raise NotImplementedError

@@ -123,3 +125,18 @@ def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
             quant_max_bound=self.quant_max_bound,
             quant_min_bound=self.quant_min_bound,
         )
+
+    def forward_gcu(self, x):
+        """
+        Forward propagation of the custom activation layer.
+
+        Args:
+            x (Tensor): Input tensor to the activation layer.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        out = swiglu(x)
+        if self.bias is not None:
+            out = out + self.bias
+        return out
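
For readers unfamiliar with SwiGLU, the arithmetic behind `forward_gcu` can be sketched in plain Python. This assumes that single-argument `swiglu(x)` splits the last axis into two equal halves and returns `silu(first_half) * second_half`, which is how Paddle documents the one-argument form; the function `swiglu_reference` below is a hypothetical scalar reference for one row, not the fused kernel used in the commit.

```python
import math
from typing import List, Optional, Sequence


def silu(value: float) -> float:
    """SiLU (swish) activation: value * sigmoid(value)."""
    return value / (1.0 + math.exp(-value))


def swiglu_reference(row: Sequence[float],
                     bias: Optional[Sequence[float]] = None) -> List[float]:
    """Reference SwiGLU for one row: split in half, gate the second half.

    The optional bias is added after the activation, matching forward_gcu.
    """
    if len(row) % 2 != 0:
        raise ValueError("the last dimension must be even to split in half")
    half = len(row) // 2
    gate, up = row[:half], row[half:]
    out = [silu(g) * u for g, u in zip(gate, up)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out
```

Note that the output has half the width of the input, so any bias must match the halved dimension.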

fastdeploy/model_executor/layers/backends/__init__.py

Lines changed: 19 additions & 9 deletions
@@ -16,14 +16,24 @@
 all backends methods
 """

-from .xpu import *
-from .npu import *
+from fastdeploy.platforms import current_platform

 __all__ = []
-from . import npu
-if hasattr(npu, '__all__'):
-    __all__.extend(npu.__all__)
-
-from . import xpu
-if hasattr(xpu, '__all__'):
-    __all__.extend(xpu.__all__)
+
+if current_platform.is_xpu():
+    from . import xpu
+    from .xpu import *
+    if hasattr(xpu, '__all__'):
+        __all__.extend(xpu.__all__)
+
+if current_platform.is_npu():
+    from . import npu
+    from .npu import *
+    if hasattr(npu, '__all__'):
+        __all__.extend(npu.__all__)
+
+if current_platform.is_gcu():
+    from . import gcu
+    from .gcu import *
+    if hasattr(gcu, '__all__'):
+        __all__.extend(gcu.__all__)
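
The change above replaces unconditional backend imports with imports gated on the current platform, so a backend module is loaded only where it is expected to be usable. A minimal sketch of that pattern follows; `collect_exports` and the loader callables are hypothetical stand-ins for the `from . import gcu` statements, and none of these names come from FastDeploy.

```python
from typing import Callable, Dict, List


def collect_exports(backends: Dict[str, Callable[[], object]], active: str) -> List[str]:
    """Load only the active backend and return the names it exports.

    Each value in `backends` is a zero-argument loader that stands in for
    an import statement; loaders for inactive platforms are never called.
    """
    exported: List[str] = []
    loader = backends.get(active)
    if loader is None:
        return exported
    module = loader()
    # Same guard as the diff: a backend without __all__ contributes nothing.
    exported.extend(getattr(module, "__all__", []))
    return exported
```

Because inactive loaders are never invoked, a backend whose import would fail on the current machine does not break package import, which is the practical benefit of the gated form.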
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+gcu backend methods
+"""
+
+from .attention.flash_attn_backend import GCUFlashAttnBackend
+from .attention.mem_efficient_attn_backend import GCUMemEfficientAttnBackend
+from .moe.fused_moe_method_gcu_backend import (GCUFusedMoeMethod,
+                                               GCUWeightOnlyMoEMethod)
+from .quantization.weight_only import GCUWeightOnlyLinearMethod
+
+__all__ = [
+    'GCUFlashAttnBackend',
+    'GCUMemEfficientAttnBackend',
+    'GCUFusedMoeMethod',
+    'GCUWeightOnlyMoEMethod',
+    'GCUWeightOnlyLinearMethod',
+]
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .flash_attn_backend import GCUFlashAttnBackend
+from .mem_efficient_attn_backend import GCUMemEfficientAttnBackend
+
+__all__ = [
+    "GCUFlashAttnBackend",
+    "GCUMemEfficientAttnBackend",
+]
