diff --git a/serverlessworkflow/sdk/state_machine_generator.py b/serverlessworkflow/sdk/state_machine_generator.py index e6058f3..e9f98cb 100644 --- a/serverlessworkflow/sdk/state_machine_generator.py +++ b/serverlessworkflow/sdk/state_machine_generator.py @@ -47,7 +47,7 @@ def __init__( "The provided state machine can not be of the HierarchicalMachine type." ) - def source_code(self): + def generate(self): self.definitions() self.transitions() @@ -182,12 +182,16 @@ def parallel_state_details(self): branches = self.state.branches if branches: if self.get_actions: + self.state_machine.get_state(state_name).initial = [] for branch in branches: if hasattr(branch, "actions") and branch.actions: branch_name = branch.name self.state_machine.get_state(state_name).add_substates( NestedState(branch_name) ) + self.state_machine.get_state(state_name).initial.append( + branch_name + ) branch_state = self.state_machine.get_state( state_name ).states[branch.name] @@ -196,12 +200,6 @@ def parallel_state_details(self): state_name=f"{state_name}.{branch_name}", actions=branch.actions, ) - self.generate_composite_state( - branch_state, - f"{state_name}.{branch_name}", - branch.actions, - "sequential", - ) def event_based_switch_state_details(self): ... @@ -242,7 +240,59 @@ def callback_state_details(self): actions=[action], ) - def generate_composite_state( + def get_subflow_state( + self, machine_state: NestedState, state_name: str, actions: List[Action] + ): + added_states = {} + for i, action in enumerate(actions): + if action.subFlowRef: + if isinstance(action.subFlowRef, str): + workflow_id = action.subFlowRef + workflow_version = None + else: + workflow_id = action.subFlowRef.workflowId + workflow_version = action.subFlowRef.version + none_found = True + for sf in self.subflows: + if sf.id == workflow_id and ( + (workflow_version and sf.version == workflow_version) + or not workflow_version + ): + none_found = False + new_machine = HierarchicalMachine( + model=None, initial=None, auto_transitions=False + ) + + # Generate the state machine for the subflow + for index, state in enumerate(sf.states): + StateMachineGenerator( + state=state, + state_machine=new_machine, + is_first_state=index == 0, + get_actions=self.get_actions, + subflows=self.subflows, + ).generate() + + # Convert the new_machine into a NestedState + added_states[i] = self.subflow_state_name( + action=action, subflow=sf + ) + nested_state = NestedState(added_states[i]) + machine_state.add_substate(nested_state) + self.state_machine_to_nested_state( + state_name=state_name, + state_machine=new_machine, + nested_state=nested_state, + ) + + if none_found: + warnings.warn( + f"Specified subflow [{workflow_id} {workflow_version if workflow_version else ''}] not found.", + category=UserWarning, + ) + return added_states + + def generate_actions_info( self, machine_state: NestedState, state_name: str, @@ -250,111 +300,70 @@ def generate_composite_state( action_mode: str = "sequential", ): parallel_states = [] - if actions: + new_subflows_names = self.get_subflow_state( + machine_state=machine_state, state_name=state_name, actions=actions + ) for i, action in enumerate(actions): - fn_name = ( - self.get_function_name(action.functionRef) - if isinstance(action.functionRef, str) - else ( - action.functionRef.refName - if isinstance(action.functionRef, FunctionRef) - else None + name = None + if action.functionRef: + name = ( + self.get_function_name(action.functionRef) + if isinstance(action.functionRef, str) + else ( + action.functionRef.refName + if isinstance(action.functionRef, FunctionRef) + else None + ) ) - ) - if fn_name: - if fn_name not in machine_state.states.keys(): - machine_state.add_substate(NestedState(fn_name)) + if name not in machine_state.states.keys(): + machine_state.add_substate(NestedState(name)) + elif action.subFlowRef: + name = new_subflows_names.get(i) + if name: if action_mode == "sequential": if i < len(actions) - 1: - next_fn_name = ( - self.get_function_name(actions[i + 1].functionRef) - if isinstance(actions[i + 1].functionRef, str) - else ( - actions[i + 1].functionRef.refName - if isinstance( - actions[i + 1].functionRef, FunctionRef + # get next name + next_name = None + if actions[i + 1].functionRef: + next_name = ( + self.get_function_name(actions[i + 1].functionRef) + if isinstance(actions[i + 1].functionRef, str) + else ( + actions[i + 1].functionRef.refName + if isinstance( + actions[i + 1].functionRef, FunctionRef + ) + else None ) - else None ) - ) - if ( - next_fn_name - not in self.state_machine.get_state( - state_name - ).states.keys() - ): - machine_state.add_substate(NestedState(next_fn_name)) + if ( + next_name + not in self.state_machine.get_state( + state_name + ).states.keys() + ): + machine_state.add_substate(NestedState(next_name)) + elif actions[i + 1].subFlowRef: + next_name = new_subflows_names.get(i + 1) self.state_machine.add_transition( trigger="", - source=f"{state_name}.{fn_name}", - dest=f"{state_name}.{next_fn_name}", + source=f"{state_name}.{name}", + dest=f"{state_name}.{next_name}", ) if i == 0: - machine_state.initial = fn_name + machine_state.initial = name elif action_mode == "parallel": - parallel_states.append(fn_name) + parallel_states.append(name) if action_mode == "parallel": machine_state.initial = parallel_states - def generate_actions_info( - self, - machine_state: NestedState, - state_name: str, - actions: List[Action], - action_mode: str = "sequential", - ): - if actions: - if self.get_actions: - self.generate_composite_state( - machine_state, - state_name, - actions, - action_mode, - ) - for action in actions: - if action.subFlowRef: - if isinstance(action.subFlowRef, str): - workflow_id = action.subFlowRef - workflow_version = None - else: - workflow_id = action.subFlowRef.workflowId - workflow_version = action.subFlowRef.version - none_found = True - for sf in self.subflows: - if sf.id == workflow_id and ( - (workflow_version and sf.version == workflow_version) - or not workflow_version - ): - none_found = False - new_machine = HierarchicalMachine( - model=None, initial=None, auto_transitions=False - ) - - # Generate the state machine for the subflow - for index, state in enumerate(sf.states): - StateMachineGenerator( - state=state, - state_machine=new_machine, - is_first_state=index == 0, - get_actions=self.get_actions, - subflows=self.subflows, - ).source_code() - - # Convert the new_machine into a NestedState - nested_state = NestedState( - action.name - if action.name - else f"{sf.id}/{sf.version.replace(NestedState.separator, '-')}" - ) - self.state_machine_to_nested_state( - state_machine=new_machine, nested_state=nested_state - ) - if none_found: - warnings.warn( - f"Specified subflow [{workflow_id} {workflow_version if workflow_version else ''}] not found.", - category=UserWarning, - ) + def subflow_state_name(self, action: Action, subflow: Workflow): + return ( + action.name + if action.name + else f"{subflow.id}/{subflow.version.replace(NestedState.separator, '-')}" + ) def add_all_sub_states( cls, @@ -366,12 +375,14 @@ def add_all_sub_states( for substate in original_state.states.values(): new_state.add_substate(ns := NestedState(substate.name)) cls.add_all_sub_states(substate, ns) + new_state.initial = original_state.initial def state_machine_to_nested_state( - self, state_machine: HierarchicalMachine, nested_state: NestedState + self, + state_name: str, + state_machine: HierarchicalMachine, + nested_state: NestedState, ) -> NestedState: - self.state_machine.get_state(self.state.name).add_substate(nested_state) - self.add_all_sub_states(state_machine, nested_state) for trigger, event in state_machine.events.items(): @@ -381,8 +392,8 @@ def state_machine_to_nested_state( dest = transition.dest self.state_machine.add_transition( trigger=trigger, - source=f"{self.state.name}.{nested_state.name}.{source}", - dest=f"{self.state.name}.{nested_state.name}.{dest}", + source=f"{state_name}.{nested_state.name}.{source}", + dest=f"{state_name}.{nested_state.name}.{dest}", ) def get_function_name( diff --git a/serverlessworkflow/sdk/state_machine_helper.py b/serverlessworkflow/sdk/state_machine_helper.py index 887bae9..e0de9e4 100644 --- a/serverlessworkflow/sdk/state_machine_helper.py +++ b/serverlessworkflow/sdk/state_machine_helper.py @@ -32,10 +32,13 @@ def __init__( ) for index, state in enumerate(workflow.states): StateMachineGenerator( - state=state, state_machine=self.machine, is_first_state=index == 0, get_actions=self.get_actions, subflows=subflows - ).source_code() - - + state=state, + state_machine=self.machine, + is_first_state=index == 0, + get_actions=self.get_actions, + subflows=subflows, + ).generate() + delattr(self.machine, "get_graph") self.machine.add_model(machine_type.self_literal)