
Commit 6aaed45

Merge branch 'develop' into mm_structred_output
2 parents f6c02a2 + 240d623 commit 6aaed45

File tree: 112 files changed (+5000 / -1264 lines)


README.md

Lines changed: 4 additions & 1 deletion

@@ -8,14 +8,17 @@
 <a href="https://github.com/PaddlePaddle/FastDeploy/commits"><img src="https://img.shields.io/github/commit-activity/m/PaddlePaddle/FastDeploy?color=3af"></a>
 <a href="https://github.com/PaddlePaddle/FastDeploy/issues"><img src="https://img.shields.io/github/issues/PaddlePaddle/FastDeploy?color=9cc"></a>
 <a href="https://github.com/PaddlePaddle/FastDeploy/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/FastDeploy?color=ccf"></a>
+
 </p>

 <p align="center">
+<a href="https://trendshift.io/repositories/4046" target="_blank"><img src="https://trendshift.io/api/badge/repositories/4046" alt="PaddlePaddle%2FFastDeploy | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></br>
 <a href="https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/"><b> Installation </b></a>
 |
 <a href="https://paddlepaddle.github.io/FastDeploy/get_started/quick_start"><b> Quick Start </b></a>
 |
-<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
+<a href="https://paddlepaddle.github.io/FastDeploy/supported_models/"><b> Supported Models </b></a>
+
 </p>

 --------------------------------------------------------------------------------

build.sh

Lines changed: 6 additions & 1 deletion

@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
 PYTHON_VERSION=${2:-"python"}
 export python=$PYTHON_VERSION
 FD_CPU_USE_BF16=${3:-"false"}
+# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
+# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
+# These will be translated to 90a / 100a in setup_ops.py for architecture-specific features.
 FD_BUILDING_ARCS=${4:-""}


@@ -74,8 +77,10 @@ function copy_ops(){
     is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
     if [ "$is_rocm" = "True" ]; then
         DEVICE_TYPE="rocm"
+        mkdir -p ../fastdeploy/model_executor/ops/base
+        cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
-        echo -e "ROCM ops have been copy to fastdeploy"
+        echo -e "BASE and ROCM ops have been copied to fastdeploy"
         return
     fi
     mkdir -p ../fastdeploy/model_executor/ops/base
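
The FD_BUILDING_ARCS comment added above says that an architecture list such as "[80, 90, 100]" is translated to 90a / 100a inside setup_ops.py. The following Python sketch only illustrates that kind of mapping; it is not the actual setup_ops.py code, and parse_building_arcs is a hypothetical helper.

# Illustrative sketch only; not the actual setup_ops.py implementation.
# Assumes SM90/SM100 builds get the "a" (arch-specific) suffix, as the
# build.sh comment above describes.
import ast

def parse_building_arcs(fd_building_arcs: str) -> list[str]:
    """Turn a string like "[80, 90, 100]" into per-arch build tags."""
    if not fd_building_arcs:
        return []
    arcs = ast.literal_eval(fd_building_arcs)  # "[80, 90, 100]" -> [80, 90, 100]
    suffixed = {90: "90a", 100: "100a"}        # SM90 (Hopper), SM100 (Blackwell)
    return [suffixed.get(a, str(a)) for a in arcs]

print(parse_building_arcs("[80, 90, 100]"))   # ['80', '90a', '100a']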

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 47 additions & 1 deletion

@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
     const paddle::Tensor &input, const paddle::Tensor &scale,
     const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
     const paddle::Tensor &token_nums_per_expert,
-    const paddle::Tensor &token_nums_per_expert_padded);
+    const paddle::Tensor &token_nums_per_expert_padded,
+    const bool use_in_ep, const int token_nums_this_rank_padded);

 std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
                                           const int block_size);

@@ -492,6 +493,31 @@ paddle::Tensor FusedHadamardQuantFp8Func(
     const float scale);
 #endif

+int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
+                               paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
+
+void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+                int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
+
+void dispose(int64_t _fa);
+
+int64_t meta_size();
+
+void register_buffer(int64_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
+
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(int64_t _fa);
+
+void register_graph_buffers(int64_t _fa,
+                            const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets);
+
+std::tuple<int64_t, paddle::Tensor> allocate_shared_buffer_and_handle(
+    int64_t size);
+
+int64_t open_mem_handle(paddle::Tensor& mem_handle);
+
+void free_shared_buffer(int64_t buffer);
+
 PYBIND11_MODULE(fastdeploy_ops, m) {

   m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),

@@ -784,4 +810,24 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
         py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
 #endif
+
+  m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");
+
+  m.def("all_reduce", &all_reduce, "all reduce function");
+
+  m.def("dispose", &dispose, "del function for python");
+
+  m.def("meta_size", &meta_size, "meta_size function for Signal struct");
+
+  m.def("register_buffer", &register_buffer, "register ipc buffer");
+
+  m.def("register_graph_buffers", &register_graph_buffers, "register_graph_buffers");
+
+  m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle");
+
+  m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
+
+  m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
+
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
 }
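
The declarations and bindings added above expose a vLLM-style custom all-reduce to Python through the fastdeploy_ops module. Below is a minimal per-rank sketch of how the new functions could fit together; the import path, the buffer sizes, and the exchange_handles helper (standing in for whatever cross-rank handle exchange FastDeploy actually uses) are assumptions for illustration, not FastDeploy's integration code.

# Hedged sketch of the per-rank setup/teardown flow for the bindings above.
# Import path, sizes, and exchange_handles() are assumptions.
import paddle
from fastdeploy.model_executor.ops import gpu as fastdeploy_ops  # assumed import path

def exchange_handles(handle_tensor):
    """Hypothetical collective: returns every rank's IPC handle tensor."""
    raise NotImplementedError  # e.g. a gather over a CPU process group

def setup(rank: int, world_size: int, max_bytes: int = 8 << 20):
    # Per-rank signal/metadata region: Signal struct plus scratch space (size is an assumption).
    signal_ptr, signal_handle = fastdeploy_ops.allocate_shared_buffer_and_handle(
        fastdeploy_ops.meta_size() + max_bytes)
    handles = exchange_handles(signal_handle)
    ipc_ptrs = [signal_ptr if r == rank else
                fastdeploy_ops.open_mem_handle(handles[r]) for r in range(world_size)]

    rank_data = paddle.empty([8 << 20], dtype="uint8")  # scratch used for graph buffers
    fa = fastdeploy_ops.init_custom_all_reduce(ipc_ptrs, rank_data, rank, True)  # True: assume full NVLink

    # IPC-registered staging buffer that all_reduce() copies inputs into.
    buf_ptr, buf_handle = fastdeploy_ops.allocate_shared_buffer_and_handle(max_bytes)
    buf_handles = exchange_handles(buf_handle)
    buf_ptrs = [buf_ptr if r == rank else
                fastdeploy_ops.open_mem_handle(buf_handles[r]) for r in range(world_size)]
    fastdeploy_ops.register_buffer(fa, buf_ptrs)
    return fa, buf_ptr, max_bytes

def custom_all_reduce(fa, x, buf_ptr, buf_bytes):
    out = paddle.empty_like(x)
    fastdeploy_ops.all_reduce(fa, x, out, buf_ptr, buf_bytes)
    return out

def teardown(fa, buf_ptr):
    fastdeploy_ops.dispose(fa)
    fastdeploy_ops.free_shared_buffer(buf_ptr)
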
Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@
+// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu
+
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+#include "all_reduce.cuh"
+
+// Fake pointer type, must match fptr_t type in ops.h.
+// We use this type alias to indicate when pointers are passed in as int64_t.
+using fptr_t = int64_t;
+static_assert(sizeof(void*) == sizeof(fptr_t));
+
+fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
+                              paddle::Tensor& rank_data, int64_t rank,
+                              bool full_nvlink) {
+  int world_size = fake_ipc_ptrs.size();
+  if (world_size > 8)
+    throw std::invalid_argument("world size > 8 is not supported");
+  if (world_size % 2 != 0)
+    throw std::invalid_argument("Odd num gpus is not supported for now");
+  if (rank < 0 || rank >= world_size)
+    throw std::invalid_argument("invalid rank passed in");
+
+  paddle::Signal* ipc_ptrs[8];
+  for (int i = 0; i < world_size; i++) {
+    ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
+  }
+  return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
+                                              rank_data.numel(), rank,
+                                              world_size, full_nvlink);
+}
+
+/**
+ * Performs an out-of-place allreduce and stores result in out.
+ *
+ * If _reg_buffer is null, assumes inp.data() is already IPC-registered.
+ * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
+ * copied into _reg_buffer.
+ */
+void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
+                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  auto stream = inp.stream();
+
+  auto input_size = inp.numel() * 2;
+  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
+  if (reg_buffer) {
+    cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
+                    cudaMemcpyDeviceToDevice, stream);
+  } else {
+    reg_buffer = inp.data();
+  }
+  switch (out.dtype()) {
+    case phi::DataType::FLOAT32: {
+      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
+                           reinterpret_cast<float*>(out.data()),
+                           out.numel());
+      break;
+    }
+    case phi::DataType::FLOAT16: {
+      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
+                          reinterpret_cast<half*>(out.data()), out.numel());
+      break;
+    }
+#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
+    case phi::DataType::BFLOAT16: {
+      fa->allreduce<nv_bfloat16>(
+          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
+          reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
+      break;
+    }
+#endif
+    default:
+      throw std::runtime_error(
+          "custom allreduce only supports float32, float16 and bfloat16");
+  }
+}
+
+void dispose(fptr_t _fa) {
+  delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+}
+
+int64_t meta_size() { return sizeof(paddle::Signal); }
+
+void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  void* ipc_ptrs[8];
+  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
+    ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
+  }
+  fa->register_buffer(ipc_ptrs);
+}
+
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+std::tuple<std::vector<int64_t>, std::vector<int64_t>>
+get_graph_buffer_ipc_meta(fptr_t _fa) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
+  std::vector<int64_t> bytes(handle.begin(), handle.end());
+  return std::make_tuple(bytes, offsets);
+}
+
+// Use vector<int64_t> to represent byte data for python binding compatibility.
+void register_graph_buffers(fptr_t _fa,
+                            const std::vector<std::vector<int64_t>>& handles,
+                            const std::vector<std::vector<int64_t>>& offsets) {
+  auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
+  std::vector<std::string> bytes;
+  bytes.reserve(handles.size());
+  for (int i = 0; i < handles.size(); i++) {
+    bytes.emplace_back(handles[i].begin(), handles[i].end());
+  }
+  bytes.reserve(handles.size());
+  fa->register_graph_buffers(bytes, offsets);
+}
+
+std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
+    int64_t size) {
+
+  auto device_index = phi::backends::gpu::GetCurrentDeviceId();
+  void* buffer;
+  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
+  auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
+  CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Allocate buffer
+  CUDACHECK(cudaMalloc((void**)&buffer, size));
+  CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
+  CUDACHECK(cudaStreamSynchronize(stream));
+  CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Create IPC memhandle for the allocated buffer.
+  // Will use it in open_mem_handle.
+  auto handle =
+      paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
+  CUDACHECK(
+      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
+
+  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
+}
+
+fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
+  void* ipc_ptr;
+  CUDACHECK(cudaIpcOpenMemHandle(
+      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
+      cudaIpcMemLazyEnablePeerAccess));
+  return reinterpret_cast<fptr_t>(ipc_ptr);
+}
+
+void free_shared_buffer(fptr_t buffer) {
+  CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
+}
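
get_graph_buffer_ipc_meta and register_graph_buffers exist so that buffers allocated while capturing CUDA graphs with the custom all-reduce can be IPC-registered across ranks after capture. A short hedged sketch of that exchange follows; gather_across_ranks stands in for an unspecified CPU-side collective, and the import path is the same assumption as in the sketch above.

# Hedged sketch: after graph capture, each rank publishes its graph-buffer IPC
# metadata and registers every rank's buffers with its CustomAllreduce instance.
from fastdeploy.model_executor.ops import gpu as fastdeploy_ops  # assumed import path

def gather_across_ranks(values):
    """Hypothetical CPU-side all-gather of per-rank lists; not FastDeploy code."""
    raise NotImplementedError

def register_captured_graph_buffers(fa):
    handle_bytes, offsets = fastdeploy_ops.get_graph_buffer_ipc_meta(fa)
    all_handles = gather_across_ranks(handle_bytes)  # list[list[int]], one entry per rank
    all_offsets = gather_across_ranks(offsets)       # list[list[int]], one entry per rank
    fastdeploy_ops.register_graph_buffers(fa, all_handles, all_offsets)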
