Commit f72c4de

[SOT] Make custom_op dy&st unified (#2733)
* make custom_op dy&st unified
* add instance judgement

Parent: f6ffbc3

File tree

3 files changed, +22 −21 lines

fastdeploy/engine/engine.py

Lines changed: 4 additions & 4 deletions

@@ -931,8 +931,8 @@ def _exit_sub_services(self):
 
     def _setting_environ_variables(self):
         """
-        Configure environment variables
-        """
+        Configure environment variables
+        """
         variables = {
            "PADDLE_TRAINER_ID": 0,
            "PADDLE_TRAINERS_NUM": 1,
@@ -998,8 +998,8 @@ def _start_worker_service(self):
        py_script = os.path.join(current_dir_path, worker_path)
 
        ori_vocab_size = (
-            len(self.data_processor.tokenizer.sp_model)
-            if hasattr(self.data_processor.tokenizer, 'sp_model')
+            len(self.data_processor.tokenizer.sp_model)
+            if hasattr(self.data_processor.tokenizer, 'sp_model')
            else len(self.data_processor.tokenizer.vocab)
        )
 
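The ori_vocab_size expression above is a duck-typing fallback: prefer the sentencepiece model's size when the tokenizer wraps one, otherwise fall back to the plain vocab dict. A minimal sketch of the same pattern (DummyTokenizer is a hypothetical stand-in, not a FastDeploy class):

class DummyTokenizer:
    """Hypothetical tokenizer with a plain vocab dict and no sp_model attribute."""
    vocab = {"<pad>": 0, "<s>": 1, "</s>": 2}

tokenizer = DummyTokenizer()
ori_vocab_size = (
    len(tokenizer.sp_model)        # sentencepiece-backed tokenizers expose sp_model
    if hasattr(tokenizer, "sp_model")
    else len(tokenizer.vocab)      # plain vocab-dict tokenizers
)
print(ori_vocab_size)  # 3 -- DummyTokenizer has no sp_model attribute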

fastdeploy/import_ops.py

Lines changed: 17 additions & 16 deletions

@@ -15,7 +15,6 @@
 import functools
 import importlib
 import inspect
-import os
 
 import paddle
 
@@ -77,7 +76,13 @@ def wrap_unified_op(original_cpp_ext_op, original_custom_op):
    @functools.wraps(original_custom_op)
    def unified_op(*args, **kwargs):
        if paddle.in_dynamic_mode():
-            return original_cpp_ext_op(*args, **kwargs)
+            res = original_cpp_ext_op(*args, **kwargs)
+            if res is None:
+                return None
+            # TODO(DrRyanHuang): Remove this if when we align the implementation of custom op and C++ extension
+            if isinstance(res, list) and len(res) == 1:
+                return res[0]
+            return res
        return original_custom_op(*args, **kwargs)
 
    return unified_op
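The unified_op wrapper above irons out a return-convention mismatch: in dynamic mode the C++ extension entry point may wrap a single output in a one-element list, while the custom op used in static mode returns the value directly. A minimal runnable sketch of the same dispatch (fake_cpp_ext_op and fake_custom_op are hypothetical stand-ins, not real FastDeploy ops):

import functools

import paddle


def fake_cpp_ext_op(x):
    # hypothetical C++ extension entry point: wraps its single output in a list
    return [x * 2]


def fake_custom_op(x):
    # hypothetical custom op: returns the tensor directly
    return x * 2


def wrap_unified_op(cpp_ext_op, custom_op):
    @functools.wraps(custom_op)
    def unified_op(*args, **kwargs):
        if paddle.in_dynamic_mode():
            res = cpp_ext_op(*args, **kwargs)
            if res is None:
                return None
            if isinstance(res, list) and len(res) == 1:
                return res[0]  # unwrap so callers see the same result in both modes
            return res
        return custom_op(*args, **kwargs)

    return unified_op


op = wrap_unified_op(fake_cpp_ext_op, fake_custom_op)
print(op(paddle.to_tensor([1.0])))  # a tensor, not a one-element list

Unwrapping the one-element list means call sites see the same result regardless of which underlying op ran; that is what lets the ernie4_5_vl_moe.py call site below drop its [0] indexing.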
@@ -93,17 +98,13 @@ def preprocess_static_op(global_ns):
    """
    static_op_prefix = "static_op_"
    static_op_names = [k for k in global_ns if k.startswith(static_op_prefix)]
-    enforce_eager = int(os.getenv("FD_ENFORCE_EAGER", "0")) == 1
-
-    for static_op in static_op_names:
-        op_name = static_op[len(static_op_prefix):]
-        has_dynamic_op = op_name in global_ns
-
-        if has_dynamic_op:
-            if not enforce_eager:
-                original_cpp_ext_op = global_ns[op_name]
-                original_custom_op = global_ns[static_op]
-                global_ns[op_name] = wrap_unified_op(original_cpp_ext_op,
-                                                     original_custom_op)
-        else:
-            global_ns[op_name] = global_ns[static_op]
+
+    for static_op_name in static_op_names:
+        op_name = static_op_name.removeprefix(static_op_prefix)
+        if op_name not in global_ns:
+            global_ns[op_name] = global_ns[static_op_name]
+            continue
+
+        original_cpp_ext_op = global_ns[op_name]
+        original_custom_op = global_ns[static_op_name]
+        global_ns[op_name] = wrap_unified_op(original_cpp_ext_op, original_custom_op)
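The rewritten loop above mutates a module's global namespace: every static_op_<name> entry either fills in a missing <name> directly or is fused with the existing dynamic-mode entry through wrap_unified_op. A minimal sketch against a plain dict (the op names and bodies are hypothetical, and the unification step is replaced by a print for brevity):

def static_op_relu(x):
    return max(x, 0)

global_ns = {
    "relu": lambda x: max(x, 0),       # pre-existing dynamic-mode entry
    "static_op_relu": static_op_relu,  # static counterpart of an existing op
    "static_op_gelu": lambda x: x,     # static op with no dynamic counterpart
}

static_op_prefix = "static_op_"
static_op_names = [k for k in global_ns if k.startswith(static_op_prefix)]

for static_op_name in static_op_names:
    op_name = static_op_name.removeprefix(static_op_prefix)  # str.removeprefix needs Python 3.9+
    if op_name not in global_ns:
        # no dynamic-mode entry yet: expose the static op under the plain name
        global_ns[op_name] = global_ns[static_op_name]
        continue
    # both entries exist: the real code fuses the pair with wrap_unified_op here
    print(f"would unify {op_name!r} with {static_op_name!r}")

print("gelu" in global_ns)  # True: filled in from static_op_gelu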

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 1 addition & 1 deletion

@@ -445,7 +445,7 @@ def forward(
            forward_meta.seq_lens_this_time,
            forward_meta.cu_seqlens_q,
            score_text,
-        )[0].cast(self._dtype)
+        ).cast(self._dtype)
        # -----------------------
 
        out = self.norm(hidden_states)
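This call-site change is the flip side of the unified_op change in import_ops.py: once the wrapper unwraps one-element list results itself, the call above already yields a bare tensor, so the old [0] indexing would slice into the tensor instead of unwrapping a list.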
