
Commit 1f28bdf

dcu adapter ernie45t (#2756)
Co-authored-by: lifu <lifu@sugon.com>
Co-authored-by: yongqiangma <xing.wo@163.com>
1 parent 03a7499 commit 1f28bdf

30 files changed: +1133 −41 lines

build.sh

Lines changed: 3 additions & 1 deletion
@@ -77,8 +77,10 @@ function copy_ops(){
     is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
     if [ "$is_rocm" = "True" ]; then
         DEVICE_TYPE="rocm"
+        mkdir -p ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
-        echo -e "ROCM ops have been copy to fastdeploy"
+        echo -e "BASE and ROCM ops have been copy to fastdeploy"
         return
     fi
     mkdir -p ../fastdeploy/model_executor/ops/base

custom_ops/gpu_ops/helper.h

Lines changed: 15 additions & 0 deletions
@@ -214,11 +214,19 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
   *addr_vec = vec;
 }
 
+#ifdef PADDLE_WITH_HIP
+template <int Size>
+HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
+                             int8_t *addr) {
+  printf("Error: Store hip_bfloat16 to int8_t is not supported!");
+}
+#else
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<__nv_bfloat16, Size> &vec,
                              int8_t *addr) {
   printf("Error: Store __nv_bfloat16 to int8_t is not supported!");
 }
+#endif
 
 template <int Size>
 HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
@@ -478,7 +486,12 @@ template <typename T>
 static void PrintMatrix3(const T *mat_d, int num, std::string name) {
 
   std::vector<T> tmp(num);
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
+#else
   cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
+#endif
+
 
   std::ofstream outfile;
   outfile.open(name + ".txt", std::ios::out);
@@ -495,6 +508,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
   outfile.close();
 }
 
+#ifndef PADDLE_WITH_HIP
 __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
                                                     int mode = 0) {
   uint32_t flag;
@@ -534,6 +548,7 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
       cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
   return max_shared_mem_per_block_opt_in;
 }
+#endif
 
 inline int GetSMVersion() {
   static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
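The helper.h changes all follow one pattern: add a hip_bfloat16 overload next to the existing __nv_bfloat16 one, route the device-to-host copy through hipMemcpy when PADDLE_WITH_HIP is defined, and fence CUDA-only helpers (ld_flag_acquire through get_cuda_max_shared_memory_per_block_opt_in) behind #ifndef PADDLE_WITH_HIP. Below is a minimal, self-contained sketch of that guard style; it is not FastDeploy code, and the DEMO_* macros, the fill kernel, and print_first are illustrative stand-ins.

```cpp
// Sketch only: one source file that builds with either hipcc (HIP/DCU path)
// or nvcc (CUDA path), using the same PADDLE_WITH_HIP switch the hunk relies on.
#include <cstdio>
#include <vector>

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#define DEMO_MALLOC(ptr, bytes) hipMalloc(ptr, bytes)
#define DEMO_FREE(ptr) hipFree(ptr)
#define DEMO_MEMCPY_D2H(dst, src, bytes) \
  hipMemcpy(dst, src, bytes, hipMemcpyDeviceToHost)
#else
#include <cuda_runtime.h>
#define DEMO_MALLOC(ptr, bytes) cudaMalloc(ptr, bytes)
#define DEMO_FREE(ptr) cudaFree(ptr)
#define DEMO_MEMCPY_D2H(dst, src, bytes) \
  cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost)
#endif

// Small kernel so there is something on the device to copy back.
__global__ void fill(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<float>(i);
}

// Same shape as PrintMatrix3: copy `num` device elements to the host,
// then inspect them on the CPU.
void print_first(const float* mat_d, int num) {
  std::vector<float> tmp(num);
  DEMO_MEMCPY_D2H(tmp.data(), mat_d, sizeof(float) * num);
  std::printf("first=%f last=%f\n", tmp.front(), tmp.back());
}

int main() {
  const int n = 256;
  float* d = nullptr;
  DEMO_MALLOC(reinterpret_cast<void**>(&d), n * sizeof(float));
  fill<<<(n + 127) / 128, 128>>>(d, n);
  print_first(d, n);
  DEMO_FREE(d);
  return 0;
}
```

Compiling with hipcc and -DPADDLE_WITH_HIP (a macro the Paddle ROCm toolchain is expected to define for the real ops) takes the HIP branch; plain nvcc takes the CUDA branch.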

custom_ops/gpu_ops/set_data_ipc.cu

Lines changed: 5 additions & 0 deletions
@@ -91,7 +91,12 @@ void set_data_ipc(const paddle::Tensor& tmp_input,
   memset((void *)shm, 0, sizeof(*shm));
 
   void *data_ptr_now = reinterpret_cast<void*>(const_cast<data_t*>(tmp_input.data<data_t>()));
+#ifdef PADDLE_WITH_HIP
+  checkCudaErrors(hipIpcGetMemHandle((hipIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
+#else
   checkCudaErrors(cudaIpcGetMemHandle((cudaIpcMemHandle_t *)&shm->memHandle, data_ptr_now));
+#endif
+
 
 }
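set_data_ipc keeps its existing flow, zeroing a shared-memory struct and taking the tensor's device pointer, and only swaps cudaIpcGetMemHandle for hipIpcGetMemHandle on HIP. The sketch below shows that export step with the stock HIP runtime API; it is not the FastDeploy op (DemoShm and CHECK_HIP are stand-ins for shmStruct and checkCudaErrors), but hipIpcGetMemHandle itself is the documented call.

```cpp
// Export side of the IPC handshake, as a standalone sketch. In the real op
// the handle is written into a POSIX shared-memory segment for other
// processes to pick up; here it just lives in a local struct.
#include <cstdio>
#include <hip/hip_runtime.h>

struct DemoShm {
  hipIpcMemHandle_t memHandle;  // stands in for shm->memHandle
};

#define CHECK_HIP(cmd)                                                 \
  do {                                                                 \
    hipError_t e = (cmd);                                              \
    if (e != hipSuccess) {                                             \
      std::printf("HIP error: %s (%s:%d)\n", hipGetErrorString(e),     \
                  __FILE__, __LINE__);                                 \
      return 1;                                                        \
    }                                                                  \
  } while (0)

int main() {
  float* d_ptr = nullptr;
  CHECK_HIP(hipMalloc(reinterpret_cast<void**>(&d_ptr), 1024 * sizeof(float)));

  DemoShm shm{};
  // Same call the hunk adds: publish an IPC handle for an existing device
  // allocation so another process can map it.
  CHECK_HIP(hipIpcGetMemHandle(&shm.memHandle, d_ptr));

  std::printf("IPC handle exported for device pointer %p\n",
              static_cast<void*>(d_ptr));
  CHECK_HIP(hipFree(d_ptr));
  return 0;
}
```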

custom_ops/gpu_ops/share_external_data.cu

Lines changed: 8 additions & 0 deletions
@@ -37,10 +37,18 @@ std::vector<paddle::Tensor> ShareExternalData(paddle::Tensor& input,
   }
   shm = (volatile shmStruct *)info.addr;
   void *ptr = nullptr;
+#ifdef PADDLE_WITH_HIP
+  checkCudaErrors(
+      hipIpcOpenMemHandle(&ptr,
+                          *(hipIpcMemHandle_t *)&shm->memHandle,  // NOLINT
+                          hipIpcMemLazyEnablePeerAccess));
+#else
   checkCudaErrors(
       cudaIpcOpenMemHandle(&ptr,
                            *(cudaIpcMemHandle_t *)&shm->memHandle,  // NOLINT
                            cudaIpcMemLazyEnablePeerAccess));
+#endif
+
   paddle::Tensor tmp_tensor = paddle::from_blob(
       ptr,
       shape,
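share_external_data is the consuming side of the same IPC handshake: it opens the handle published by set_data_ipc with hipIpcOpenMemHandle and then wraps the mapped pointer in a paddle::Tensor via paddle::from_blob. A hedged sketch of just the open/read/close half in plain HIP (no Paddle) follows; the open_and_peek function and the assumption that the handle arrives via shared memory are illustrative.

```cpp
// Import side only, as a sketch. The handle is assumed to have been filled
// by hipIpcGetMemHandle in the exporting process and transported here
// (FastDeploy uses a shared-memory segment for this).
#include <cstdio>
#include <hip/hip_runtime.h>

int open_and_peek(const hipIpcMemHandle_t& handle) {
  void* ptr = nullptr;
  hipError_t e =
      hipIpcOpenMemHandle(&ptr, handle, hipIpcMemLazyEnablePeerAccess);
  if (e != hipSuccess) {
    std::printf("hipIpcOpenMemHandle failed: %s\n", hipGetErrorString(e));
    return 1;
  }

  // Read one element through the mapped pointer to confirm access.
  float first = 0.f;
  hipMemcpy(&first, ptr, sizeof(float), hipMemcpyDeviceToHost);
  std::printf("first element of shared buffer: %f\n", first);

  // Unmap when done; the exporting process still owns the allocation.
  hipIpcCloseMemHandle(ptr);
  return 0;
}
```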

custom_ops/setup_ops.py

Lines changed: 39 additions & 30 deletions
@@ -187,39 +187,45 @@ def find_end_files(directory, end_str):
 if paddle.is_compiled_with_rocm():
     # NOTE(@duanyanhui): paddle.is_compiled_with_cuda() returns True when paddle compiled with rocm.
     # so we need to check if paddle compiled with rocm at first.
+    json_dir = "third_party/nlohmann_json"
+    if not os.path.exists(json_dir) or not os.listdir(json_dir):
+        if not os.path.exists(json_dir):
+            os.makedirs(json_dir)
+        clone_git_repo("v3.11.3", "https://bgithub.xyz/nlohmann/json.git", json_dir)
+        if not os.listdir(json_dir):
+            raise ValueError("Git clone nlohmann_json failed!")
+    sources=[
+        "gpu_ops/set_value_by_flags.cu",
+        "gpu_ops/token_penalty_multi_scores.cu",
+        "gpu_ops/stop_generation.cu",
+        "gpu_ops/stop_generation_multi_ends.cu",
+        "gpu_ops/get_padding_offset.cu",
+        "gpu_ops/update_inputs.cu",
+        "gpu_ops/rebuild_padding.cu",
+        "gpu_ops/step.cu",
+        "gpu_ops/set_data_ipc.cu",
+        "gpu_ops/moe/tritonmoe_preprocess.cu",
+        "gpu_ops/step_system_cache.cu",
+        "gpu_ops/get_output_ep.cc",
+        "gpu_ops/speculate_decoding/speculate_get_padding_offset.cu",
+        "gpu_ops/speculate_decoding/speculate_get_output.cc",
+        "gpu_ops/share_external_data.cu",
+        "gpu_ops/speculate_decoding/speculate_clear_accept_nums.cu",
+        "gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu",
+        "gpu_ops/speculate_decoding/speculate_get_seq_lens_output.cu",
+        "gpu_ops/speculate_decoding/speculate_save_output.cc",
+        "gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu",
+        "gpu_ops/speculate_decoding/speculate_step.cu",
+        "gpu_ops/speculate_decoding/speculate_step_system_cache.cu",
+        "gpu_ops/speculate_decoding/speculate_update_v3.cu",
+        "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
+        "gpu_ops/fused_rotary_position_encoding.cu",
+        "gpu_ops/step_reschedule.cu",
+    ]
     setup(
         name="fastdeploy_ops",
         ext_modules=CUDAExtension(
-            sources=[
-                "gpu_ops/save_with_output.cc",
-                "gpu_ops/set_mask_value.cu",
-                "gpu_ops/set_value_by_flags.cu",
-                "gpu_ops/ngram_mask.cu",
-                "gpu_ops/gather_idx.cu",
-                "gpu_ops/token_penalty_multi_scores.cu",
-                "gpu_ops/token_penalty_only_once.cu",
-                "gpu_ops/stop_generation.cu",
-                "gpu_ops/stop_generation_multi_ends.cu",
-                "gpu_ops/stop_generation_multi_stop_seqs.cu",
-                "gpu_ops/set_flags.cu",
-                "gpu_ops/fused_get_rope.cu",
-                "gpu_ops/transfer_output.cc",
-                "gpu_ops/get_padding_offset.cu",
-                "gpu_ops/update_inputs.cu",
-                "gpu_ops/update_inputs_beam.cu",
-                "gpu_ops/beam_search_softmax.cu",
-                "gpu_ops/rebuild_padding.cu",
-                "gpu_ops/save_with_output_msg.cc",
-                "gpu_ops/get_output.cc",
-                "gpu_ops/get_output_msg_with_topk.cc",
-                "gpu_ops/step.cu",
-                "gpu_ops/step_reschedule.cu",
-                "gpu_ops/set_data_ipc.cu",
-                "gpu_ops/read_data_ipc.cu",
-                "gpu_ops/dequant_int8.cu",
-                "gpu_ops/enforce_generation.cu",
-                "gpu_ops/tune_cublaslt_gemm.cu",
-            ],
+            sources=sources,
             extra_compile_args={
                 "cxx": ["-O3"],
                 "hipcc": [
@@ -231,6 +237,9 @@ def find_end_files(directory, end_str):
                     "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
                     "-U__HIP_NO_BFLOAT162_OPERATORS__",
                     "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
+                    "-DPADDLE_DEV",
+                    "-Ithird_party/nlohmann_json/include",
+                    "-Igpu_ops",
                 ],
             },
         ),

docs/get_started/installation/README.md

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Kunlun XPU Installation](kunlunxin_xpu.md)
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)

docs/get_started/installation/hygon_dcu.md

Lines changed: 81 additions & 0 deletions (new file)

# Run the ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B models on a Hygon machine

The current release is a demonstration of the Hygon K100AI running large models with the FastDeploy inference framework. Issues may still occur with the latest ERNIE 4.5 models; fixes and performance optimization will follow, and later releases will provide a more stable version.

## Requirements

First, prepare a machine with the following configuration:
- OS: Linux
- Python: 3.10
- Memory: 2 TB
- Disk: 4 TB
- DCU model: K100AI
- DCU driver version: ≥ 6.3.8-V1.9.2

## 1. Set up using Docker (Recommended)

```bash
mkdir Work
cd Work
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10

docker run -it \
    --network=host \
    --name=ernie45t \
    --privileged \
    --device=/dev/kfd \
    --device=/dev/dri \
    --ipc=host \
    --shm-size=16G \
    --group-add video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    -u root \
    --ulimit stack=-1:-1 \
    --ulimit memlock=-1:-1 \
    -v `pwd`:/home \
    -v /opt/hyhal:/opt/hyhal:ro \
    image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10 /bin/bash
```

## 2. Start the service

```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
    --model "/models/ERNIE-45-Turbo/ERNIE-4.5-300B-A47B-Paddle/" \
    --port 8188 \
    --tensor-parallel-size 8 \
    --quantization=wint8 \
    --gpu-memory-utilization=0.8
```

#### Send requests

Send requests using either curl or Python:

```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
  "messages": [
    {"role": "user", "content": "Where is the capital of China?"}
  ]
}'
```

```python
import openai

ip = "0.0.0.0"
service_http_port = "8188"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?"},
    ],
    temperature=1,
    max_tokens=1024,
    stream=False,
)
print(response)
```

docs/zh/get_started/installation/README.md

Lines changed: 1 addition & 0 deletions
@@ -6,3 +6,4 @@ FastDeploy currently supports installation on the following hardware platforms:
 - [Kunlunxin XPU Installation](kunlunxin_xpu.md)
 - [Enflame S60 GCU Installation](Enflame_gcu.md)
 - [Iluvatar GPU Installation](iluvatar_gpu.md)
+- [Hygon DCU Installation](hygon_dcu.md)

docs/zh/get_started/installation/hygon_dcu.md

Lines changed: 81 additions & 0 deletions (new file)

# Run ERNIE-4.5-300B-A47B & ERNIE-4.5-21B-A3B on the Hygon K100AI with FastDeploy

The current release is a demonstration of large-model inference with FastDeploy on the K100AI. Issues may still occur with the latest ERNIE 4.5 models; fixes and performance optimization will follow, and a more stable version will be provided later.

## Prepare a machine

First, prepare a machine with the following configuration:
- OS: Linux
- Python: 3.10
- Memory: 2 TB
- Disk: 4 TB
- DCU model: K100AI
- DCU driver version: ≥ 6.3.8-V1.9.2

## 1. Set up using Docker (Recommended)

```bash
mkdir Work
cd Work
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10

docker run -it \
    --network=host \
    --name=ernie45t \
    --privileged \
    --device=/dev/kfd \
    --device=/dev/dri \
    --ipc=host \
    --shm-size=16G \
    --group-add video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    -u root \
    --ulimit stack=-1:-1 \
    --ulimit memlock=-1:-1 \
    -v `pwd`:/home \
    -v /opt/hyhal:/opt/hyhal:ro \
    image.sourcefind.cn:5000/dcu/admin/base/custom:fastdeploy2.0.0-kylinv10-dtk25.04-py3.10 /bin/bash
```

## 2. Start the service

```bash
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
python -m fastdeploy.entrypoints.openai.api_server \
    --model "/models/ERNIE-45-Turbo/ERNIE-4.5-300B-A47B-Paddle/" \
    --port 8188 \
    --tensor-parallel-size 8 \
    --quantization=wint8 \
    --gpu-memory-utilization=0.8
```

#### Send requests

You can send requests to the service over the OpenAI protocol, using either curl or Python:

```bash
curl -X POST "http://0.0.0.0:8188/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
  "messages": [
    {"role": "user", "content": "Where is the capital of China?"}
  ]
}'
```

```python
import openai

ip = "0.0.0.0"
service_http_port = "8188"
client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY")

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Eliza's rate per hour for the first 40 hours she works each week is $10. She also receives an overtime pay of 1.2 times her regular hourly rate. If Eliza worked for 45 hours this week, how much are her earnings for this week?"},
    ],
    temperature=1,
    max_tokens=1024,
    stream=False,
)
print(response)
```

fastdeploy/model_executor/layers/attention/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -20,9 +20,11 @@
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
+from .block_multihead_attn_backend import BlockAttentionBackend
 
 __all__ = [
     "AttentionBackend", "PaddleNativeAttnBackend",
     "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
-    "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend"
+    "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend",
+    "BlockAttentionBackend"
 ]
