Commit 240d623

[Fix]fix top_k_top_p sampling (#2801)
* fix topk-topp
* update
* add base_non_truncated
1 parent 5907126 commit 240d623

8 files changed: +23 -123 lines changed

docs/usage/environment_variables.md

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ATTENTION_BACKEND":
     lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
 
-    # Sampling class ("base", "air", or "rejection")
+    # Sampling class ("base", "base_non_truncated", "air", or "rejection")
     "FD_SAMPLING_CLASS":
     lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
 

docs/zh/usage/environment_variables.md

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ATTENTION_BACKEND":
     lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
 
-    # 设置采样类别,当前可设置为 "base"、"air" 或 "rejection"
+    # 设置采样类别,当前可设置为 "base"、"base_non_truncated"、"air" 或 "rejection"
     "FD_SAMPLING_CLASS":
     lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
 

fastdeploy/envs.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@
     "FD_ATTENTION_BACKEND":
     lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"),
 
-    # Set sampling class. "base", "air" and "rejection" can be set currently.
+    # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently.
     "FD_SAMPLING_CLASS":
     lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
 
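The new class is selected the same way as the existing ones. A minimal sketch of opting in from Python (exporting the variable in the shell before launching FastDeploy works equally well; "base" remains the fallback read by os.getenv):

import os

# Choose the sampler implementation read from fastdeploy/envs.py.
# Valid values after this commit: "base" (default), "base_non_truncated",
# "air", "rejection".
os.environ["FD_SAMPLING_CLASS"] = "base_non_truncated"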

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py

Lines changed: 12 additions & 9 deletions
@@ -71,6 +71,14 @@ def top_k_top_p_sampling(
     elif top_p_class == "rejection":
         ids = rejection_top_p_sampling(x, top_p, top_k, seed, order)
         _ = None
+    elif top_p_class == "base_non_truncated":
+        _, ids = paddle.tensor.top_p_sampling(x,
+                                              top_p,
+                                              threshold=threshold,
+                                              topp_seed=topp_seed,
+                                              seed=seed,
+                                              k=k,
+                                              mode="non-truncated")
     else:
         if current_platform.is_gcu():
             _, ids = gcu_top_p_sampling(x, top_p)
@@ -81,7 +89,7 @@ def top_k_top_p_sampling(
                                                   topp_seed=topp_seed,
                                                   seed=seed,
                                                   k=k,
-                                                  mode=mode)
+                                                  mode="truncated")
     return _, ids
 
 
@@ -109,26 +117,25 @@ def air_top_p_sampling(
 def rejection_top_p_sampling(
     x: paddle.Tensor,
     top_p: paddle.Tensor,
-    top_k: Optional[paddle.Tensor] = None,
+    top_k: paddle.Tensor,
     seed: int = -1,
     order: Literal['top_k_first', 'joint'] = "top_k_first",
 ) -> paddle.Tensor:
     """
     rejection_top_p_sampling
     """
-    assert top_p is not None, "Top_p should not be none when FD_SAMPLING_CLASS is rejection"
     try:
         from fastdeploy.model_executor.ops.gpu import (
             rejection_top_p_sampling, top_k_renorm_probs)
 
-        if top_k is None:
+        if paddle.count_nonzero(top_k) == 0:
             ids = rejection_top_p_sampling(
                 x,
                 top_p,
                 None,
                 seed,
             )
-        elif top_k is not None and top_p is not None:
+        else:
             if order == "top_k_first":
                 renorm_probs = top_k_renorm_probs(x, top_k)
                 ids = rejection_top_p_sampling(
@@ -144,10 +151,6 @@ def rejection_top_p_sampling(
                     top_k,
                     seed,
                 )
-            else:
-                raise ValueError(
-                    "Top_p cannot be none."
-                )
     except ImportError:
         raise RuntimeError("Cannot import rejection_top_p_sampling op.")
     return ids
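For orientation, the dispatch after this change: "base" always calls Paddle's built-in kernel with mode pinned to "truncated", the new "base_non_truncated" class calls the same kernel with mode="non-truncated", and the rejection path now takes top_k as a required tensor, treating an all-zero top_k (paddle.count_nonzero(top_k) == 0) as "no request uses top-k". A minimal sketch of the two base modes, under assumed toy shapes (probs already softmaxed; a 1-D per-query top_p tensor is used here, whereas FastDeploy's own buffers are [batch, 1]):

import paddle

probs = paddle.nn.functional.softmax(paddle.randn([2, 32]), axis=-1)
top_p = paddle.full([2], 0.8, dtype=probs.dtype)

# "base": the else branch above now pins mode="truncated" instead of
# forwarding a mode variable.
_, ids_base = paddle.tensor.top_p_sampling(probs, top_p, mode="truncated")

# "base_non_truncated": same kernel, non-truncated mode (see Paddle's
# top_p_sampling documentation for the exact difference between the modes).
_, ids_base_nt = paddle.tensor.top_p_sampling(probs, top_p, mode="non-truncated")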

fastdeploy/worker/gcu_model_runner.py

Lines changed: 2 additions & 29 deletions
@@ -155,29 +155,12 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 -1].disaggregate_info["role"] == "prefill":
             os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
 
-        top_k_reqs = []
-        top_p_reqs = []
-        max_num_seqs = self.parallel_config.max_num_seqs
-        top_p_buffer = paddle.full([max_num_seqs, 1],
-                                   self.model_config.top_p,
-                                   dtype='float32')
-        top_k_buffer = paddle.full([max_num_seqs, 1],
-                                   0,
-                                   dtype='int64')
-
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
 
-            if sampling_params := request.sampling_params:
-                if sampling_params.top_p < 1:
-                    top_p_reqs.append(idx)
-                top_k = sampling_params.top_k
-                if top_k > 0:
-                    top_k_reqs.append(idx)
-
             prefill_tokens = []
             if (request.guided_json is not None
                     or request.guided_regex is not None
@@ -252,8 +235,8 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 request.eos_token_ids.append(request.eos_token_ids[0])
             self.share_inputs["eos_token_id"][:] = np.array(
                 request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
-            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
+            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0)
+            self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
             self.share_inputs["temperature"][idx:idx + 1] = request.get(
                 "temperature", 0.95)
             self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -304,16 +287,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts)
 
-        if len(top_k_reqs) == 0:
-            self.share_inputs["top_k"] = None
-        else:
-            self.share_inputs["top_k"] = top_k_buffer
-
-        if len(top_p_reqs) == 0:
-            self.share_inputs["top_p"] = None
-        else:
-            self.share_inputs["top_p"] = top_p_buffer
-
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                               expected_decode_len: int):
         """ Set dummy prefill inputs to share_inputs """

fastdeploy/worker/gpu_model_runner.py

Lines changed: 2 additions & 28 deletions
@@ -164,29 +164,13 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 -1].disaggregate_info["role"] == "prefill":
             os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
 
-        top_k_reqs = []
-        top_p_reqs = []
-        max_num_seqs = self.parallel_config.max_num_seqs
-        top_p_buffer = paddle.full([max_num_seqs, 1],
-                                   self.model_config.top_p,
-                                   dtype='float32')
-        top_k_buffer = paddle.full([max_num_seqs, 1],
-                                   0,
-                                   dtype='int64')
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
             assert length > 0, "The prompt requested must not be empty."
 
-            if sampling_params := request.sampling_params:
-                if sampling_params.top_p < 1:
-                    top_p_reqs.append(idx)
-                top_k = sampling_params.top_k
-                if top_k > 0:
-                    top_k_reqs.append(idx)
-
             prefill_tokens = []
             if (request.guided_json is not None
                     or request.guided_regex is not None
@@ -261,8 +245,8 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 request.eos_token_ids.append(request.eos_token_ids[0])
             self.share_inputs["eos_token_id"][:] = np.array(
                 request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
-            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
+            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0)
+            self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
             self.share_inputs["temperature"][idx:idx + 1] = request.get(
                 "temperature", 0.95)
             self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -313,16 +297,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
         if self.speculative_method in ["mtp"]:
             self.proposer.insert_prefill_inputs(req_dicts)
 
-        if len(top_k_reqs) == 0:
-            self.share_inputs["top_k"] = None
-        else:
-            self.share_inputs["top_k"] = top_k_buffer
-
-        if len(top_p_reqs) == 0:
-            self.share_inputs["top_p"] = None
-        else:
-            self.share_inputs["top_p"] = top_p_buffer
-
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                               expected_decode_len: int):
         """ Set dummy prefill inputs to share_inputs """

fastdeploy/worker/iluvatar_model_runner.py

Lines changed: 2 additions & 28 deletions
@@ -144,29 +144,12 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 -1].disaggregate_info["role"] == "prefill":
             os.environ['PREFILL_NODE_ONE_STEP_STOP'] = "1"
 
-        top_k_reqs = []
-        top_p_reqs = []
-        max_num_seqs = self.parallel_config.max_num_seqs
-        top_p_buffer = paddle.full([max_num_seqs, 1],
-                                   self.model_config.top_p,
-                                   dtype='float32')
-        top_k_buffer = paddle.full([max_num_seqs, 1],
-                                   0,
-                                   dtype='int64')
-
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             length = len(request.prompt_token_ids)
 
-            if sampling_params := request.sampling_params:
-                if sampling_params.top_p < 1:
-                    top_p_reqs.append(idx)
-                top_k = sampling_params.top_k
-                if top_k > 0:
-                    top_k_reqs.append(idx)
-
             prefill_tokens = []
             if (request.guided_json is not None
                     or request.guided_regex is not None
@@ -241,8 +224,8 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 request.eos_token_ids.append(request.eos_token_ids[0])
             self.share_inputs["eos_token_id"][:] = np.array(
                 request.eos_token_ids, dtype="int64").reshape(-1, 1)
-            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
-            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
+            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0)
+            self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
             self.share_inputs["temperature"][idx:idx + 1] = request.get(
                 "temperature", 0.95)
             self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -289,15 +272,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]):
                 idx, request.get("logits_processor"), prefill_tokens)
 
         self.share_inputs["not_need_stop"][0] = True
-        if len(top_k_reqs) == 0:
-            self.share_inputs["top_k"] = None
-        else:
-            self.share_inputs["top_k"] = top_k_buffer
-
-        if len(top_p_reqs) == 0:
-            self.share_inputs["top_p"] = None
-        else:
-            self.share_inputs["top_p"] = top_p_buffer
 
     def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
                               expected_decode_len: int):

fastdeploy/worker/xpu_model_runner.py

Lines changed: 2 additions & 26 deletions
@@ -282,26 +282,11 @@ def __init__(self, fd_config: FDConfig, device: str, rank: int,
 
     def process_prefill_inputs(self, req_dicts: List[Request]):
         """ Process inputs for prefill tasks and update share_inputs buffer """
-        top_k_reqs = []
-        top_p_reqs = []
-        max_num_seqs = self.parallel_config.max_num_seqs
-        top_p_buffer = paddle.full([max_num_seqs, 1],
-                                   self.model_config.top_p,
-                                   dtype='float32')
-        top_k_buffer = paddle.full([max_num_seqs, 1],
-                                   0,
-                                   dtype='int64')
         req_len = len(req_dicts)
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             length = request.prompt_token_ids_len
-            if sampling_params := request.sampling_params:
-                if sampling_params.top_p < 1:
-                    top_p_reqs.append(idx)
-                top_k = sampling_params.top_k
-                if top_k > 0:
-                    top_k_reqs.append(idx)
             self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array(
                 request.prompt_token_ids)
             if len(request.eos_token_ids
@@ -310,8 +295,8 @@ def process_prefill_inputs(self, req_dicts: List[Request]):
             self.share_inputs["eos_token_id"][:] = np.array(
                 request.eos_token_ids, dtype="int64").reshape(-1, 1)
             self.share_inputs["pre_ids"][idx:idx + 1] = -1
-            top_p_buffer[idx:idx + 1] = request.get("top_p", 1.0)
-            top_k_buffer[idx:idx + 1] = request.get("top_k", 0)
+            self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0)
+            self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
             self.share_inputs["temperature"][idx:idx + 1] = request.get(
                 "temperature", 0.95)
             self.share_inputs["penalty_score"][idx:idx + 1] = request.get(
@@ -360,15 +345,6 @@ def process_prefill_inputs(self, req_dicts: List[Request]):
                 request.get("stop_token_ids"), dtype="int64")
 
         self.share_inputs["not_need_stop"][0] = True
-        if len(top_k_reqs) == 0:
-            self.share_inputs["top_k"] = None
-        else:
-            self.share_inputs["top_k"] = top_k_buffer
-
-        if len(top_p_reqs) == 0:
-            self.share_inputs["top_p"] = None
-        else:
-            self.share_inputs["top_p"] = top_p_buffer
 
     def _init_share_inputs(self, max_num_seqs: int):
         """Initialize all share buffers for model inputs.
