Skip to content

Commit ace169a

Browse files
authored
fix(core): Fix AWEL branch bug (#1640)
1 parent 49b56b4 commit ace169a

32 files changed

+866
-477
lines changed

dbgpt/app/scene/operators/app_operator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from dbgpt.core.awel import (
1414
DAG,
1515
BaseOperator,
16+
BranchJoinOperator,
1617
InputOperator,
17-
JoinOperator,
1818
MapOperator,
1919
SimpleCallDataInputSource,
2020
)
@@ -195,9 +195,7 @@ def build_cached_chat_operator(
195195
cache_task_name=cache_task_name,
196196
)
197197
# Create a join node to merge outputs from the model and cache nodes, just keep the first not empty output
198-
join_task = JoinOperator(
199-
combine_function=lambda model_out, cache_out: cache_out or model_out
200-
)
198+
join_task = BranchJoinOperator()
201199

202200
# Define the workflow structure using the >> operator
203201
input_task >> cache_check_branch_task

dbgpt/core/awel/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .operators.base import BaseOperator, WorkflowRunner
1818
from .operators.common_operator import (
1919
BranchFunc,
20+
BranchJoinOperator,
2021
BranchOperator,
2122
BranchTaskType,
2223
InputOperator,
@@ -78,6 +79,7 @@
7879
"ReduceStreamOperator",
7980
"TriggerOperator",
8081
"MapOperator",
82+
"BranchJoinOperator",
8183
"BranchOperator",
8284
"InputOperator",
8385
"BranchFunc",

dbgpt/core/awel/flow/base.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""The mixin of DAGs."""
2+
23
import abc
34
import dataclasses
45
import inspect
@@ -337,6 +338,9 @@ class Parameter(TypeMetadata, Serializable):
337338
value: Optional[Any] = Field(
338339
None, description="The value of the parameter(Saved in the dag file)"
339340
)
341+
alias: Optional[List[str]] = Field(
342+
None, description="The alias of the parameter(Compatible with old version)"
343+
)
340344

341345
@model_validator(mode="before")
342346
@classmethod
@@ -398,6 +402,7 @@ def build_from(
398402
description: Optional[str] = None,
399403
options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None,
400404
resource_type: ResourceType = ResourceType.INSTANCE,
405+
alias: Optional[List[str]] = None,
401406
):
402407
"""Build the parameter from the type."""
403408
type_name = type.__qualname__
@@ -419,6 +424,7 @@ def build_from(
419424
placeholder=placeholder,
420425
description=description or label,
421426
options=options,
427+
alias=alias,
422428
)
423429

424430
@classmethod
@@ -452,7 +458,7 @@ def build_from_ui(cls, data: Dict) -> "Parameter":
452458

453459
def to_dict(self) -> Dict:
454460
"""Convert current metadata to json dict."""
455-
dict_value = model_to_dict(self, exclude={"options"})
461+
dict_value = model_to_dict(self, exclude={"options", "alias"})
456462
if not self.options:
457463
dict_value["options"] = None
458464
elif isinstance(self.options, BaseDynamicOptions):
@@ -677,9 +683,18 @@ def get_runnable_parameters(
677683
for parameter in self.parameters
678684
if not parameter.optional
679685
}
680-
current_parameters = {
681-
parameter.name: parameter for parameter in self.parameters
682-
}
686+
current_parameters = {}
687+
current_aliases_parameters = {}
688+
for parameter in self.parameters:
689+
current_parameters[parameter.name] = parameter
690+
if parameter.alias:
691+
for alias in parameter.alias:
692+
if alias in current_aliases_parameters:
693+
raise FlowMetadataException(
694+
f"Alias {alias} already exists in the metadata."
695+
)
696+
current_aliases_parameters[alias] = parameter
697+
683698
if len(view_required_parameters) < len(current_required_parameters):
684699
# TODO, skip the optional parameters.
685700
raise FlowParameterMetadataException(
@@ -691,12 +706,16 @@ def get_runnable_parameters(
691706
)
692707
for view_param in view_parameters:
693708
view_param_key = view_param.name
694-
if view_param_key not in current_parameters:
709+
if view_param_key in current_parameters:
710+
current_parameter = current_parameters[view_param_key]
711+
elif view_param_key in current_aliases_parameters:
712+
current_parameter = current_aliases_parameters[view_param_key]
713+
else:
695714
raise FlowParameterMetadataException(
696715
f"Parameter {view_param_key} not in the metadata."
697716
)
698717
runnable_parameters.update(
699-
current_parameters[view_param_key].to_runnable_parameter(
718+
current_parameter.to_runnable_parameter(
700719
view_param.get_typed_value(), resources, key_to_resource_instance
701720
)
702721
)

dbgpt/core/awel/flow/compat.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,29 @@
11
"""Compatibility mapping for flow classes."""
22

3+
from dataclasses import dataclass
34
from typing import Dict, Optional
45

5-
_COMPAT_FLOW_MAPPING: Dict[str, str] = {}
6+
7+
@dataclass
8+
class _RegisterItem:
9+
"""Register item for compatibility mapping."""
10+
11+
old_module: str
12+
new_module: str
13+
old_name: str
14+
new_name: Optional[str] = None
15+
after: Optional[str] = None
16+
17+
def old_cls_key(self) -> str:
18+
"""Get the old class key."""
19+
return f"{self.old_module}.{self.old_name}"
20+
21+
def new_cls_key(self) -> str:
22+
"""Get the new class key."""
23+
return f"{self.new_module}.{self.new_name}"
24+
25+
26+
_COMPAT_FLOW_MAPPING: Dict[str, _RegisterItem] = {}
627

728

829
_OLD_AGENT_RESOURCE_MODULE_1 = "dbgpt.serve.agent.team.layout.agent_operator_resource"
@@ -11,17 +32,24 @@
1132

1233

1334
def _register(
14-
old_module: str, new_module: str, old_name: str, new_name: Optional[str] = None
35+
old_module: str,
36+
new_module: str,
37+
old_name: str,
38+
new_name: Optional[str] = None,
39+
after_version: Optional[str] = None,
1540
):
1641
if not new_name:
1742
new_name = old_name
18-
_COMPAT_FLOW_MAPPING[f"{old_module}.{old_name}"] = f"{new_module}.{new_name}"
43+
item = _RegisterItem(old_module, new_module, old_name, new_name, after_version)
44+
_COMPAT_FLOW_MAPPING[item.old_cls_key()] = item
1945

2046

2147
def get_new_class_name(old_class_name: str) -> Optional[str]:
2248
"""Get the new class name for the old class name."""
23-
new_cls_name = _COMPAT_FLOW_MAPPING.get(old_class_name, None)
24-
return new_cls_name
49+
if old_class_name not in _COMPAT_FLOW_MAPPING:
50+
return None
51+
item = _COMPAT_FLOW_MAPPING[old_class_name]
52+
return item.new_cls_key()
2553

2654

2755
_register(
@@ -54,3 +82,9 @@ def get_new_class_name(old_class_name: str) -> Optional[str]:
5482
_register(
5583
_OLD_AGENT_RESOURCE_MODULE_2, _NEW_AGENT_RESOURCE_MODULE, "AWELAgent", "AWELAgent"
5684
)
85+
_register(
86+
"dbgpt.storage.vector_store.connector",
87+
"dbgpt.serve.rag.connector",
88+
"VectorStoreConnector",
89+
after_version="v0.5.8",
90+
)

dbgpt/core/awel/flow/flow_factory.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -555,14 +555,6 @@ def build(self, flow_panel: FlowPanel) -> DAG:
555555
downstream = key_to_downstream.get(operator_key, [])
556556
if not downstream:
557557
raise ValueError("Branch operator should have downstream.")
558-
if len(downstream) != len(view_metadata.parameters):
559-
raise ValueError(
560-
"Branch operator should have the same number of downstream as "
561-
"parameters."
562-
)
563-
for i, param in enumerate(view_metadata.parameters):
564-
downstream_key, _, _ = downstream[i]
565-
param.value = key_to_operator_nodes[downstream_key].data.name
566558

567559
try:
568560
runnable_params = metadata.get_runnable_parameters(

dbgpt/core/awel/operators/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def __init__(
137137
task_name: Optional[str] = None,
138138
dag: Optional[DAG] = None,
139139
runner: Optional[WorkflowRunner] = None,
140+
can_skip_in_branch: bool = True,
140141
**kwargs,
141142
) -> None:
142143
"""Create a BaseOperator with an optional workflow runner.
@@ -157,6 +158,7 @@ def __init__(
157158

158159
self._runner: WorkflowRunner = runner
159160
self._dag_ctx: Optional[DAGContext] = None
161+
self._can_skip_in_branch = can_skip_in_branch
160162

161163
@property
162164
def current_dag_context(self) -> DAGContext:
@@ -321,6 +323,10 @@ def current_event_loop_task_id(self) -> int:
321323
"""Get the current event loop task id."""
322324
return id(asyncio.current_task())
323325

326+
def can_skip_in_branch(self) -> bool:
327+
"""Check if the operator can be skipped in the branch."""
328+
return self._can_skip_in_branch
329+
324330

325331
def initialize_runner(runner: WorkflowRunner):
326332
"""Initialize the default runner."""

dbgpt/core/awel/operators/common_operator.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ReduceFunc,
1717
TaskContext,
1818
TaskOutput,
19+
is_empty_data,
1920
)
2021
from .base import BaseOperator
2122

@@ -28,13 +29,16 @@ class JoinOperator(BaseOperator, Generic[OUT]):
2829
This node type is useful for combining the outputs of upstream nodes.
2930
"""
3031

31-
def __init__(self, combine_function: JoinFunc, **kwargs):
32+
def __init__(
33+
self, combine_function: JoinFunc, can_skip_in_branch: bool = True, **kwargs
34+
):
3235
"""Create a JoinDAGNode with a combine function.
3336
3437
Args:
3538
combine_function: A function that defines how to combine inputs.
39+
can_skip_in_branch(bool): Whether the node can be skipped in a branch.
3640
"""
37-
super().__init__(**kwargs)
41+
super().__init__(can_skip_in_branch=can_skip_in_branch, **kwargs)
3842
if not callable(combine_function):
3943
raise ValueError("combine_function must be callable")
4044
self.combine_function = combine_function
@@ -57,6 +61,12 @@ async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]:
5761
curr_task_ctx.set_task_output(join_output)
5862
return join_output
5963

64+
async def _return_first_non_empty(self, *inputs):
65+
for data in inputs:
66+
if not is_empty_data(data):
67+
return data
68+
raise ValueError("All inputs are empty")
69+
6070

6171
class ReduceStreamOperator(BaseOperator, Generic[IN, OUT]):
6272
"""Operator that reduces inputs using a custom reduce function."""
@@ -287,6 +297,32 @@ async def branches(self) -> Dict[BranchFunc[IN], BranchTaskType]:
287297
raise NotImplementedError
288298

289299

300+
class BranchJoinOperator(JoinOperator, Generic[OUT]):
301+
"""Operator that joins inputs using a custom combine function.
302+
303+
This node type is useful for combining the outputs of upstream nodes.
304+
"""
305+
306+
def __init__(
307+
self,
308+
combine_function: Optional[JoinFunc] = None,
309+
can_skip_in_branch: bool = False,
310+
**kwargs,
311+
):
312+
"""Create a JoinDAGNode with a combine function.
313+
314+
Args:
315+
combine_function: A function that defines how to combine inputs.
316+
can_skip_in_branch(bool): Whether the node can be skipped in a branch(
317+
default True).
318+
"""
319+
super().__init__(
320+
combine_function=combine_function or self._return_first_non_empty,
321+
can_skip_in_branch=can_skip_in_branch,
322+
**kwargs,
323+
)
324+
325+
290326
class InputOperator(BaseOperator, Generic[OUT]):
291327
"""Operator node that reads data from an input source."""
292328

dbgpt/core/awel/runner/local_runner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
This runner will run the workflow in the current process.
44
"""
5+
56
import asyncio
67
import logging
78
import traceback
@@ -11,7 +12,7 @@
1112

1213
from ..dag.base import DAGContext, DAGVar
1314
from ..operators.base import CALL_DATA, BaseOperator, WorkflowRunner
14-
from ..operators.common_operator import BranchOperator, JoinOperator
15+
from ..operators.common_operator import BranchOperator
1516
from ..task.base import SKIP_DATA, TaskContext, TaskState
1617
from ..task.task_impl import DefaultInputContext, DefaultTaskContext, SimpleTaskOutput
1718
from .job_manager import JobManager
@@ -184,14 +185,14 @@ def _skip_current_downstream_by_node_name(
184185
return
185186
for child in branch_node.downstream:
186187
child = cast(BaseOperator, child)
187-
if child.node_name in skip_nodes:
188+
if child.node_name in skip_nodes or child.node_id in skip_node_ids:
188189
logger.info(f"Skip node name {child.node_name}, node id {child.node_id}")
189190
_skip_downstream_by_id(child, skip_node_ids)
190191

191192

192193
def _skip_downstream_by_id(node: BaseOperator, skip_node_ids: Set[str]):
193-
if isinstance(node, JoinOperator):
194-
# Not skip join node
194+
if not node.can_skip_in_branch():
195+
# Current node can not skip, so skip its downstream
195196
return
196197
skip_node_ids.add(node.node_id)
197198
for child in node.downstream:

dbgpt/core/awel/tests/test_run_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def join_func(o1, o2) -> int:
130130
even_node = MapOperator(
131131
lambda x: 888, task_id="even_node", task_name="even_node_name"
132132
)
133-
join_node = JoinOperator(join_func)
133+
join_node = JoinOperator(join_func, can_skip_in_branch=False)
134134
branch_node = BranchOperator(
135135
{lambda x: x % 2 == 1: odd_node, lambda x: x % 2 == 0: even_node}
136136
)

0 commit comments

Comments
 (0)