Skip to content

Commit b0f5259

Browse files
authored
[SOT] Remove breakgraph in post processing && fix datatype (#2780)
1 parent 2ea267f commit b0f5259

File tree

3 files changed: +20 additions, -17 deletions

fastdeploy/model_executor/layers/attention/ops/get_block_shape_and_split_kv_block.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,11 @@
1818

1919
from fastdeploy.platforms import current_platform
2020

21+
if current_platform.is_cuda():
22+
from fastdeploy.model_executor.ops.gpu import \
23+
get_block_shape_and_split_kv_block as \
24+
get_block_shape_and_split_kv_block_cuda
25+
2126

2227
def get_block_shape_and_split_kv_block(
2328
seq_lens_encoder: paddle.Tensor,
@@ -34,7 +39,6 @@ def get_block_shape_and_split_kv_block(
3439
get_block_shape_and_split_kv_block
3540
"""
3641
if current_platform.is_cuda():
37-
from fastdeploy.model_executor.ops.gpu import get_block_shape_and_split_kv_block
3842
(
3943
encoder_batch_ids,
4044
encoder_tile_ids_per_batch,
@@ -47,7 +51,7 @@ def get_block_shape_and_split_kv_block(
4751
decoder_num_blocks,
4852
max_len_kv,
4953
set_max_lengths,
50-
) = get_block_shape_and_split_kv_block(
54+
) = get_block_shape_and_split_kv_block_cuda(
5155
seq_lens_encoder,
5256
seq_lens_decoder,
5357
seq_lens_this_time,

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ def forward(
395395
image_mask = ids_remove_padding == self.im_patch_id
396396
token_type_ids = image_mask.cast("int32")
397397
token_num = hidden_states.shape[0]
398-
image_token_num = paddle.count_nonzero(token_type_ids).cast("int32")
399-
text_token_num = paddle.maximum(token_num - image_token_num, paddle.ones([], dtype="int32"))
398+
image_token_num = paddle.count_nonzero(token_type_ids)
399+
text_token_num = paddle.maximum((token_num - image_token_num), paddle.ones([], dtype="int64"))
400400
if image_mask.any():
401401
hidden_states[image_mask] = image_features.cast(self._dtype)
402402
text_input = paddle.full(
@@ -444,7 +444,7 @@ def forward(
444444
hidden_states = extract_text_token_output(
445445
max_seq_len,
446446
max_seq_len_index.cast("int32"),
447-
image_token_num,
447+
image_token_num.cast("int32"),
448448
forward_meta.seq_lens_this_time,
449449
forward_meta.cu_seqlens_q,
450450
score_text,

fastdeploy/worker/vl_gpu_model_runner.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -929,18 +929,17 @@ def post_process(self, next_tokens: paddle.Tensor) -> None:
929929
False,
930930
) # multi ends
931931
# update inputs
932-
with paddle.framework._no_check_dy2st_diff():
933-
update_inputs(
934-
self.share_inputs["stop_flags"],
935-
self.share_inputs["not_need_stop"],
936-
self.share_inputs["seq_lens_this_time"],
937-
self.share_inputs["seq_lens_encoder"],
938-
self.share_inputs["seq_lens_decoder"],
939-
self.share_inputs["input_ids"],
940-
self.share_inputs["stop_nums"],
941-
next_tokens,
942-
self.share_inputs["is_block_step"],
943-
)
932+
update_inputs(
933+
self.share_inputs["stop_flags"],
934+
self.share_inputs["not_need_stop"],
935+
self.share_inputs["seq_lens_this_time"],
936+
self.share_inputs["seq_lens_encoder"],
937+
self.share_inputs["seq_lens_decoder"],
938+
self.share_inputs["input_ids"],
939+
self.share_inputs["stop_nums"],
940+
next_tokens,
941+
self.share_inputs["is_block_step"],
942+
)
944943
save_output(
945944
next_tokens,
946945
self.share_inputs["not_need_stop"],

0 commit comments

Comments
 (0)