Commit 1e2319c

Rename top_p_sampling to top_k_top_p_sampling (#2791)
1 parent e45050c · commit 1e2319c

File tree

5 files changed: +23 −16 lines

docs/offline_inference.md

Lines changed: 9 additions & 4 deletions
@@ -4,6 +4,7 @@
 FastDeploy supports offline inference by loading models locally and processing user data. Usage examples:

 ### Chat Interface (LLM.chat)
+
 ```python
 from fastdeploy import LLM, SamplingParams

@@ -77,10 +78,12 @@ for output in outputs:
 prompt = output.prompt
 generated_text = output.outputs.text
 ```
+
 > Note: Text completion interface, suitable for scenarios where users have predefined the context input and expect the model to output only the continuation content. No additional `prompt` concatenation will be added during the inference process.
 > For the `chat` model, it is recommended to use the Chat Interface (`LLM.chat`).

 For multimodal models, such as `baidu/ERNIE-4.5-VL-28B-A3B-Paddle`, when calling the `generate interface`, you need to provide a prompt that includes images. The usage is as follows:
+
 ```python
 import io
 import os
@@ -96,7 +99,7 @@ tokenizer = ErnieBotTokenizer.from_pretrained(os.path.dirname(PATH))

 messages = [
 {
-"role": "user",
+"role": "user",
 "content": [
 {"type":"image_url", "image_url": {"url":"https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0"}},
 {"type":"text", "text":"这张图片的内容是什么"}
@@ -141,6 +144,7 @@ for output in outputs:
 reasoning_text = output.outputs.reasoning_content

 ```
+
 >Note: The `generate interface` does not currently support passing parameters to control the thinking function (on/off). It always uses the model's default parameters.

 ## 2. API Documentation
@@ -159,12 +163,12 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * messages(list[dict],list[list[dict]]): Input messages (batch supported)
 * sampling_params: See 2.4 for parameter details
 * use_tqdm: Enable progress visualization
-* chat_template_kwargs(dict): Extra template parameters (currently supports enable_thinking(bool))
+* chat_template_kwargs(dict): Extra template parameters (currently supports enable_thinking(bool))
 *usage example: `chat_template_kwargs={"enable_thinking": False}`*

 ### 2.3 fastdeploy.LLM.generate

-* prompts(str, list[str], list[int], list[list[int]], dict[str, Any], list[dict[str, Any]]): : Input prompts (batch supported), accepts decoded token ids
+* prompts(str, list[str], list[int], list[list[int]], dict[str, Any], list[dict[str, Any]]): : Input prompts (batch supported), accepts decoded token ids
 *example of using a dict-type parameter: `prompts={"prompt": prompt, "multimodal_data": {"image": images}}`*
 * sampling_params: See 2.4 for parameter details
 * use_tqdm: Enable progress visualization
@@ -176,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * repetition_penalty(float): Direct penalty for repeated tokens (>1 penalizes, <1 encourages)
 * temperature(float): Controls randomness (higher = more random)
 * top_p(float): Probability threshold for token selection
+* top_k(int): Number of tokens considered for sampling
 * max_tokens(int): Maximum generated tokens (input + output)
 * min_tokens(int): Minimum forced generation length

@@ -206,4 +211,4 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * first_token_time(float): First token latency
 * time_in_queue(float): Queuing time
 * model_forward_time(float): Forward pass duration
-* model_execute_time(float): Total execution time (including preprocessing)
+* model_execute_time(float): Total execution time (including preprocessing)
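
As a quick illustration of the `top_k` entry added to the sampling parameter list above, the sketch below pairs it with `top_p` through `SamplingParams` and the offline `LLM.generate` interface documented in this file. This is not part of the commit: the model path and prompt are placeholders, and any keyword names beyond those listed in the doc are assumptions rather than a verified API surface.

```python
# Minimal sketch (not from the commit): exercising the documented top_p / top_k
# sampling parameters through the offline inference interface.
from fastdeploy import LLM, SamplingParams

# Restrict sampling to the 50 most probable tokens, then apply the top_p cutoff.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=50, max_tokens=128)

llm = LLM(model="./ERNIE-4.5-0.3B-Paddle")  # hypothetical local model directory

outputs = llm.generate(prompts=["Write a short poem about the sea."],
                       sampling_params=sampling_params)

for output in outputs:
    print(output.prompt)
    print(output.outputs.text)
```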

docs/zh/offline_inference.md

Lines changed: 8 additions & 6 deletions
@@ -78,10 +78,12 @@ for output in outputs:
 prompt = output.prompt
 generated_text = output.outputs.text
 ```
-> 注: 续写接口, 适应于用户自定义好上下文输入, 并希望模型仅输出续写内容的场景; 推理过程不会增加其他 `prompt `拼接。
+
+> 注: 续写接口, 适应于用户自定义好上下文输入, 并希望模型仅输出续写内容的场景; 推理过程不会增加其他 `prompt`拼接。
 > 对于 `chat`模型, 建议使用对话接口(LLM.chat)。

 对于多模模型, 例如`baidu/ERNIE-4.5-VL-28B-A3B-Paddle`, 在调用`generate接口`时, 需要提供包含图片的prompt, 使用方式如下:
+
 ```python
 import io
 import os
@@ -97,7 +99,7 @@ tokenizer = ErnieBotTokenizer.from_pretrained(os.path.dirname(PATH))

 messages = [
 {
-"role": "user",
+"role": "user",
 "content": [
 {"type":"image_url", "image_url": {"url":"https://ku.baidu-int.com/vk-assets-ltd/space/2024/09/13/933d1e0a0760498e94ec0f2ccee865e0"}},
 {"type":"text", "text":"这张图片的内容是什么"}
@@ -142,6 +144,7 @@ for output in outputs:
 reasoning_text = output.outputs.reasoning_content

 ```
+
 > 注: `generate` 接口, 暂时不支持思考开关参数控制, 均使用模型默认思考能力。

 ## 2. 接口说明
@@ -155,18 +158,17 @@ for output in outputs:
 > 2. 模型服务启动后,会在日志文件log/fastdeploy.log中打印如 `Doing profile, the total_block_num:640` 的日志,其中640即表示自动计算得到的KV Cache block数量,将它乘以block_size(默认值64),即可得到部署后总共可以在KV Cache中缓存的Token数。
 > 3. `max_num_seqs` 用于配置decode阶段最大并发处理请求数,该参数可以基于第1点中缓存的Token数来计算一个较优值,例如线上统计输入平均token数800, 输出平均token数500,本次计算得到KV Cache block为640, block_size为64。那么我们可以配置 `kv_cache_ratio = 800 / (800 + 500) = 0.6` , 配置 `max_seq_len = 640 * 64 / (800 + 500) = 31`

-
 ### 2.2 fastdeploy.LLM.chat

 * messages(list[dict],list[list[dict]]): 输入的message, 支持batch message 输入
 * sampling_params: 模型超参设置具体说明见2.4
 * use_tqdm: 是否打开推理进度可视化
-* chat_template_kwargs(dict): 传递给对话模板的额外参数,当前支持enable_thinking(bool)
+* chat_template_kwargs(dict): 传递给对话模板的额外参数,当前支持enable_thinking(bool)
 *使用示例`chat_template_kwargs={"enable_thinking": False}`*

 ### 2.3 fastdeploy.LLM.generate

-* prompts(str, list[str], list[int], list[list[int]], dict[str, Any], list[dict[str, Any]]): 输入的prompt, 支持batch prompt 输入,解码后的token ids 进行输入
+* prompts(str, list[str], list[int], list[list[int]], dict[str, Any], list[dict[str, Any]]): 输入的prompt, 支持batch prompt 输入,解码后的token ids 进行输入
 *dict 类型使用示例`prompts={"prompt": prompt, "multimodal_data": {"image": images}}`*
 * sampling_params: 模型超参设置具体说明见2.4
 * use_tqdm: 是否打开推理进度可视化
@@ -178,7 +180,7 @@ for output in outputs:
 * repetition_penalty(float): 直接对重复生成的token进行惩罚的系数(>1时惩罚重复,<1时鼓励重复)
 * temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
 * top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合
-* top_k(int): 采样概率最高的的token数量,考虑概率最高的k个token进行采样
+* top_k(int): 采样概率最高的token数量,考虑概率最高的k个token进行采样
 * max_tokens(int): 限制模型生成的最大token数量(包括输入和输出)
 * min_tokens(int): 强制模型生成的最少token数量,避免过早结束


fastdeploy/model_executor/layers/sample/ops/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -16,10 +16,10 @@

 from .apply_penalty_multi_scores import (
 apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
-from .top_p_sampling import top_p_sampling
+from .top_k_top_p_sampling import top_k_top_p_sampling

 __all__ = [
 "apply_penalty_multi_scores",
 "apply_speculative_penalty_multi_scores",
-"top_p_sampling",
+"top_k_top_p_sampling",
 ]
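
Downstream code now imports the operator under its new public name from the same ops package; a minimal sketch of the changed import path (assuming FastDeploy is installed):

```python
# The ops package exports the operator under its new name after this commit.
from fastdeploy.model_executor.layers.sample.ops import top_k_top_p_sampling

# The old export was removed from __all__ and the module was renamed, so the
# previous import no longer works:
# from fastdeploy.model_executor.layers.sample.ops import top_p_sampling  # ImportError
```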

fastdeploy/model_executor/layers/sample/ops/top_p_sampling.py renamed to fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 from fastdeploy.model_executor.ops.gcu import \
 top_p_sampling as gcu_top_p_sampling

-def top_p_sampling(
+def top_k_top_p_sampling(
 x: paddle.Tensor,
 top_p: paddle.Tensor,
 top_k: Optional[paddle.Tensor] = None,

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,7 @@
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.ops import (
 apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
-top_p_sampling)
+top_k_top_p_sampling)
 from fastdeploy.platforms import current_platform


@@ -214,7 +214,7 @@ def forward_cuda(

 probs = F.softmax(logits)

-_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
+_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)

 self.processor.update_output_tokens(next_tokens, skip_idx_list)
 return next_tokens
@@ -367,5 +367,5 @@ def forward_cuda(
 )
 probs = F.softmax(logits)

-_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
+_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
 return next_tokens
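
The sampler call sites above show the expected inputs for the renamed operator: a softmax distribution plus per-request `top_p`/`top_k` tensors. Below is a standalone, illustrative sketch of that call pattern only; the tensor shapes, dtypes, and values are assumptions for demonstration, and running it requires FastDeploy's custom sampling ops to be built.

```python
# Illustrative sketch of the renamed operator's call pattern, mirroring the
# sampler.py call sites above. Shapes and dtypes are assumptions, not taken
# from the FastDeploy sources.
import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.layers.sample.ops import top_k_top_p_sampling

batch_size, vocab_size = 2, 32000
logits = paddle.randn([batch_size, vocab_size], dtype="float32")
probs = F.softmax(logits)  # same preprocessing as in Sampler.forward_cuda

# One top_p / top_k value per request in the batch (assumed layout).
top_p = paddle.full([batch_size, 1], 0.95, dtype="float32")
top_k = paddle.full([batch_size, 1], 50, dtype="int64")

# Same unpacking as the call sites in sampler.py: (scores, sampled token ids).
_, next_tokens = top_k_top_p_sampling(probs, top_p, top_k)
print(next_tokens)
```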
