
Commit 05c670e

[Sync] Update to latest code (#2679)
* [Sync] Update to latest code
* Add new code files
* Add new code files
* update code
* Try to fix build.sh
* Try to fix build.sh
* Update code
* Update requirements.txt
* Update code

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
1 parent d222248 commit 05c670e

95 files changed: +9914, -1310 lines


build.sh

Lines changed: 4 additions & 1 deletion
```diff
@@ -166,7 +166,7 @@ function build_and_install() {
 
 echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
 cd $DIST_DIR
-${python} -m pip install ./dist/fastdeploy*.whl --force-reinstall --no-cache-dir
+find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
 if [ $? -ne 0 ]; then
 cd ..
 echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
@@ -228,6 +228,9 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
 ${BLUE}fastdeploy branch:${NONE} $EFFLLM_BRANCH ($EFFLLM_COMMIT)\n"
 
 echo -e "${GREEN}wheel saved under${NONE} ${RED}${BOLD}./dist${NONE}"
+
+# install wheel
+${python} -m pip install ./dist/fastdeploy*.whl
 echo -e "${GREEN}wheel install success${NONE}\n"
 
 trap : 0
```
Lines changed: 236 additions & 0 deletions (new file)
```cpp
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "multi_head_latent_attention_kernel.h"

// Online-softmax state for one head: accumulator o, running max m, running sum d.
template <size_t vec_size, typename T>
struct softmax_state_t {
  AlignedVector<T, vec_size> o;
  T m;
  T d;

  __device__ __forceinline__ void init() {
    // Zero the accumulator two elements at a time (half2 / bfloat162).
    if constexpr (std::is_same<T, half>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((half2*)(&o) + i) = make_half2(0, 0);
      }
    } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
      for (int i = 0; i < vec_size / 2; ++i) {
        *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0);
      }
    }
    d = 1.f;
    // m starts at a large negative sentinel near the dtype's lowest value.
    if constexpr (std::is_same<T, half>::value) {
      m = __float2half(-5e4f);
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = __float2bfloat16(-3.38953e38f);
    }
  }

  __device__ __forceinline__ softmax_state_t() {
    init();
  }

  // Online-softmax combine: rescale both partial results to the shared max,
  // then add the sums and the weighted accumulators.
  __device__ __forceinline__ void merge(const AlignedVector<T, vec_size>& other_o,
                                        T other_m,
                                        T other_d) {
    T m_prev = m, d_prev = d;
    m = m_prev > other_m ? m_prev : other_m;
    T scale1 = hexp(m_prev - m), scale2 = hexp(other_m - m);

    d = d_prev * scale1 + other_d * scale2;

#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] = o[i] * scale1 + other_o[i] * scale2;
    }
  }

  __device__ __forceinline__ void normalize() {
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      o[i] /= d;
    }
  }
};

// Tiled variant: one softmax state (m, d) shared across num_tiles output tiles.
template <size_t vec_size, typename T, uint32_t num_tiles = 0>
struct softmax_state_ts {
  uint32_t num_tiles_ = num_tiles;
  AlignedVector<T, vec_size> o[num_tiles];
  float m;
  float d;

  __device__ __forceinline__ void init() {
#pragma unroll
    for (uint32_t tile_id = 0; tile_id < num_tiles_; ++tile_id) {
      if constexpr (std::is_same<T, half>::value) {
#pragma unroll
        for (int i = 0; i < vec_size / 2; ++i) {
          *((half2*)(&o[tile_id]) + i) = make_half2(0, 0);
        }
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
#pragma unroll
        for (int i = 0; i < vec_size / 2; ++i) {
          *((nv_bfloat162*)(&o[tile_id]) + i) = make_bfloat162(0, 0);
        }
      }
    }
    d = 1.f;
    if constexpr (std::is_same<T, half>::value) {
      m = -5e4f;
    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
      m = -3.38953e38f;
    }
  }

  __device__ __forceinline__ softmax_state_ts() {
    init();
  }

  __device__ __forceinline__ void normalize(const uint32_t tile_id) {
#pragma unroll
    for (size_t i = 0; i < vec_size; i++) {
      o[tile_id][i] /= d;
    }
  }
};

// Stage one tile of the paged KV cache from global memory into shared memory.
template <SharedMemFillMode fill_mode, uint32_t HEAD_DIM_QK, uint32_t vec_size,
          uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t BLOCK_SIZE,
          uint32_t CACHE_VEC_SIZE, typename CacheT>
__device__ __forceinline__ void produce_kv(CacheT *smem,
                                           CacheT *kv_base_gptr,
                                           const int *block_table_smem,
                                           const uint32_t seq_offset_gmem,
                                           const uint32_t seq_offset_smem,
                                           const uint32_t kv_head_idx,
                                           const uint32_t kv_num_heads,
                                           const uint32_t tidx,
                                           const uint32_t chunk_start,
                                           const uint32_t chunk_end) {
  // Map the logical token position to a physical block via the block table.
  int block_id = __ldg(&block_table_smem[seq_offset_gmem / BLOCK_SIZE]);
  if (block_id < 0) {
    block_id = 0;
  }
  const uint32_t block_offset = seq_offset_gmem % BLOCK_SIZE;
  // 128-bit loads: 8 T or 16 int8 elements per access.
  const uint32_t k_offset_base = ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
  const uint32_t smem_offset_base = seq_offset_smem * HEAD_DIM_QK;
  for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
    pred_load<128, PrefetchMode::kPrefetch, fill_mode, CacheT>(
        smem + smem_offset_base + vid * CACHE_VEC_SIZE,
        kv_base_gptr + k_offset_base + vid * CACHE_VEC_SIZE,
        seq_offset_gmem < chunk_end);
  }
}

// Compute q.k scores for one tile, reduce them across the warp, and update
// the online-softmax state (running max m, running sum d, rescaled o).
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx, uint32_t bdy,
          uint32_t HEAD_DIM, uint32_t DEAL_EACH_TIME, uint32_t num_tile_v,
          typename T, typename CacheT>
__device__ __forceinline__ void compute_qk(const T* cu_q_smem,
                                           const CacheT* k_smem,
                                           const uint32_t kv_idx_base,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           const uint32_t gid,
                                           const float scale,
                                           float *s,
                                           softmax_state_ts<vec_size, T, num_tile_v>& st) {
  const CacheT* smem;
  AlignedVector<T, vec_size> q_vec;
  AlignedVector<T, vec_size> k_vec;
  float m_prev = st.m;
  smem = k_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM;
#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    if (iter_base + j < iter_bound) {
      s[j] = 0.f;
#pragma unroll
      for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
        Load<T, vec_size>(cu_q_smem + vid * vec_size, &q_vec);
        Load<CacheT, vec_size>(smem + j * HEAD_DIM + vid * vec_size, &k_vec);
        for (uint32_t i = 0; i < vec_size; ++i) {
          s[j] += static_cast<float>(q_vec[i] * k_vec[i]);
        }
      }
      // Butterfly (xor) shuffle: sum the partial dot products across the warp.
#pragma unroll
      for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
        s[j] += __shfl_xor_sync(-1, s[j], offset, 32);
      }
      __syncthreads();
    } else {
      // Out-of-range positions get the dtype's "-inf" so they vanish after exp().
      if constexpr (std::is_same<T, half>::value) {
        s[j] = -5e4f;
      } else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
        s[j] = -3.38953e38f;
      }
    }
    st.m = st.m > s[j] ? st.m : s[j];
  }

  // Rescale the running sum and accumulators to the new max, then fold in
  // the exponentiated scores of this tile.
  float o_scale = __expf(m_prev - st.m);
  st.d *= o_scale;

#pragma unroll
  for (uint32_t j = 0; j < DEAL_EACH_TIME; ++j) {
    s[j] = __expf(s[j] - st.m);
    st.d += s[j];
  }
#pragma unroll
  for (uint32_t tile_id = 0; tile_id < num_tile_v; ++tile_id) {
    for (uint32_t i = 0; i < vec_size; ++i) {
      st.o[tile_id][i] *= o_scale;
    }
  }
}

// Accumulate the softmax-weighted values: o += s[j] * v for each in-range j.
template <uint32_t vec_size, uint32_t NUM_VEC_PER_HEAD, uint32_t bdx,
          uint32_t DEAL_EACH_TIME, uint32_t HEAD_DIM_QK, uint32_t num_tile,
          typename T, typename CacheT>
__device__ __forceinline__ void compute_sv(const float *s,
                                           const CacheT *base_v_smem,
                                           const uint32_t stage_idx,
                                           const uint32_t iter_base,
                                           const uint32_t iter_bound,
                                           const uint32_t tidx,
                                           softmax_state_ts<vec_size, T, num_tile>& st) {
  const CacheT* v_smem;
  AlignedVector<T, vec_size> v_vec;
#pragma unroll
  for (int j = 0; (j < DEAL_EACH_TIME) && (iter_base + j < iter_bound); ++j) {
    v_smem = base_v_smem + stage_idx * DEAL_EACH_TIME * HEAD_DIM_QK + j * HEAD_DIM_QK;
    for (uint32_t vid = tidx; vid < NUM_VEC_PER_HEAD; vid += bdx) {
      Load<T, vec_size>(v_smem + vid * vec_size, &v_vec);
      uint32_t tile_id = vid / bdx;
#pragma unroll
      for (int reg_id = 0; reg_id < vec_size; ++reg_id) {
        st.o[tile_id][reg_id] += static_cast<T>(s[j]) * v_vec[reg_id];
      }
    }
  }
}
```
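For context, `softmax_state_t::merge` above is the standard online-softmax combine used by flash-attention-style kernels: each partial state is rescaled to the shared running max before the sums and accumulators are added. Below is a minimal host-side sketch of the same rule (scalar value, and the clean `0 / -inf` initialization rather than the kernel's `d = 1`, `m` near dtype-min; all names here are illustrative, not from this commit):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar analogue of softmax_state_t: running max m, running sum d, and a
// softmax-weighted accumulator o (vec_size collapsed to 1 for clarity).
struct State {
  float o = 0.f, m = -INFINITY, d = 0.f;

  // Fold one (score, value) pair into the state.
  void push(float s, float v) {
    float m_new = std::fmax(m, s);
    float scale = std::exp(m - m_new);  // rescale the old accumulator
    float p = std::exp(s - m_new);      // weight of the new element
    o = o * scale + p * v;
    d = d * scale + p;
    m = m_new;
  }

  // Combine two partial states over disjoint chunks, as in softmax_state_t::merge.
  void merge(const State& other) {
    float m_new = std::fmax(m, other.m);
    float s1 = std::exp(m - m_new), s2 = std::exp(other.m - m_new);
    o = o * s1 + other.o * s2;
    d = d * s1 + other.d * s2;
    m = m_new;
  }

  float result() const { return o / d; }  // normalize()
};

int main() {
  std::vector<float> scores = {0.3f, 2.1f, -1.0f, 0.7f};
  std::vector<float> values = {1.0f, 2.0f, 3.0f, 4.0f};

  // Two chunks processed independently, then merged...
  State a, b;
  for (int i = 0; i < 2; ++i) a.push(scores[i], values[i]);
  for (int i = 2; i < 4; ++i) b.push(scores[i], values[i]);
  a.merge(b);

  // ...must agree with a single pass over all four elements.
  State full;
  for (int i = 0; i < 4; ++i) full.push(scores[i], values[i]);

  std::printf("merged=%f single-pass=%f\n", a.result(), full.result());
  return 0;
}
```

Running it prints the same number twice: merging two chunk-local states is equivalent to a single pass over all scores, which is what lets the kernel split a long KV sequence into independently processed chunks.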

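`produce_kv` resolves a logical token position to a physical cache address through the block table: the quotient `seq_offset / BLOCK_SIZE` selects a block-table entry, and the remainder indexes within that block. A small host-side sketch of just this index arithmetic; the layout comment and the constants are assumptions for illustration, not values from the commit:

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors produce_kv's addressing. The cache layout is assumed to be
// [num_blocks, kv_num_heads, BLOCK_SIZE, HEAD_DIM_QK].
constexpr uint32_t BLOCK_SIZE = 64;    // tokens per physical cache block (illustrative)
constexpr uint32_t HEAD_DIM_QK = 576;  // elements per head (illustrative)

uint32_t kv_element_offset(const int* block_table, uint32_t seq_offset,
                           uint32_t kv_head_idx, uint32_t kv_num_heads) {
  int block_id = block_table[seq_offset / BLOCK_SIZE];    // logical -> physical block
  if (block_id < 0) block_id = 0;                         // unmapped entries fall back to block 0
  const uint32_t block_offset = seq_offset % BLOCK_SIZE;  // position inside the block
  return ((block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE + block_offset) * HEAD_DIM_QK;
}

int main() {
  // Logical blocks 0..2 live in physical blocks 7, 3, and (unmapped) -1.
  const int block_table[] = {7, 3, -1};
  // Token 70 -> block_table[1] = 3, in-block offset 6 -> (3*64 + 6) * 576.
  std::printf("%u\n", kv_element_offset(block_table, 70, 0, 1));  // prints 114048
  return 0;
}
```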
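`compute_qk` reduces each thread's partial dot product with a xor-shuffle butterfly, so every lane ends up holding the full warp sum without touching shared memory. A standalone sketch of that pattern (kernel name and launch shape are illustrative):

```cpp
#include <cstdio>

// Each round, lane i exchanges its value with lane i^offset; after log2(32)
// rounds every lane holds the sum over all 32 lanes. Same loop shape as the
// reduction in compute_qk (starting offset bdx / 2, full warp mask).
__global__ void warp_sum(const float* in, float* out) {
  float v = in[threadIdx.x];
#pragma unroll
  for (unsigned offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffffu, v, offset, 32);
  }
  if (threadIdx.x == 0) *out = v;
}

int main() {
  float h_in[32], h_out = 0.f;
  float *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;  // expected sum: 32
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp sum = %f\n", h_out);  // prints 32.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```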