From 134bf4697d6e9305a658b43706bca6de73388583 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 13 Jun 2025 12:28:09 -0400 Subject: [PATCH 1/5] Fix: Handle run_in_parallel=False, simplify pending function call tracking --- CHANGELOG.md | 5 + src/pipecat_flows/manager.py | 183 +++++++++++++++-------------------- 2 files changed, 84 insertions(+), 104 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c423044..9280b27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -130,6 +130,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed an issue where if `run_in_parallel=False` was set for the LLM, the bot + would trigger N completions for each sequential function call. Now, Flows + uses Pipecat's internal function tracking to determine when there are more + edge functions to call. + - Overhauled `pre_actions` and `post_actions` timing logic, making their timing more predictable and eliminating some bugs. For example, now `tts_say` actions will always run after the bot response, when used in `post_actions`. diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 9baa027..de46b80 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -132,7 +132,7 @@ def __init__( self.adapter = create_adapter(llm) self.initialized = False self._context_aggregator = context_aggregator - self._pending_function_calls = 0 + self._pending_transition: Optional[Dict[str, Any]] = None self._context_strategy = context_strategy or ContextStrategyConfig( strategy=ContextStrategy.APPEND ) @@ -329,87 +329,10 @@ async def _create_transition_func( if transition_callback: self._validate_transition_callback(name, transition_callback) - def decrease_pending_function_calls() -> None: - """Decrease the pending function calls counter if greater than zero.""" - if self._pending_function_calls > 0: - self._pending_function_calls -= 1 - logger.debug( - f"Function call completed: {name} (remaining: {self._pending_function_calls})" - ) - - async def on_context_updated_edge( - next_node: Optional[NodeConfig | str], - args: Optional[Dict[str, Any]], - result: Optional[Any], - result_callback: Callable, - ) -> None: - """ - Handle context updates for edge functions with transitions. - - If `next_node` is provided: - - Ignore `args` and `result` and just transition to it. - - Otherwise, if `transition_to` is available: - - Use it to look up the next node. - - Otherwise, if `transition_callback` is provided: - - Call it with `args` and `result` to determine the next node. - """ - try: - decrease_pending_function_calls() - - # Only process transition if this was the last pending call - if self._pending_function_calls == 0: - if next_node: # Function-returned next node (as opposed to next node specified by transition_*) - if isinstance(next_node, str): # Static flow - node_name = next_node - node = self.nodes[next_node] - else: # Dynamic flow - node_name = get_or_generate_node_name(next_node) - node = next_node - logger.debug(f"Transition to function-returned node: {node_name}") - await self._set_node(node_name, node) - elif transition_to: # Static flow - logger.debug(f"Static transition to: {transition_to}") - await self._set_node(transition_to, self.nodes[transition_to]) - elif transition_callback: # Dynamic flow - logger.debug(f"Dynamic transition for: {name}") - # Check callback signature - sig = inspect.signature(transition_callback) - if len(sig.parameters) == 2: - # Old style: (args, flow_manager) - await transition_callback(args, self) - else: - # New style: (args, result, flow_manager) - await transition_callback(args, result, self) - # Reset counter after transition completes - self._pending_function_calls = 0 - logger.debug("Reset pending function calls counter") - else: - logger.debug( - f"Skipping transition, {self._pending_function_calls} calls still pending" - ) - except Exception as e: - logger.error(f"Error in transition: {str(e)}") - self._pending_function_calls = 0 - await result_callback( - {"status": "error", "error": str(e)}, - properties=None, # Clear properties to prevent further callbacks - ) - raise # Re-raise to prevent further processing - - async def on_context_updated_node() -> None: - """Handle context updates for node functions without transitions.""" - decrease_pending_function_calls() - async def transition_func(params: FunctionCallParams) -> None: """Inner function that handles the actual tool invocation.""" try: - # Track pending function call - self._pending_function_calls += 1 - logger.debug( - f"Function call pending: {name} (total: {self._pending_function_calls})" - ) + logger.debug(f"Function called: {name}") # Execute handler if present is_transition_only_function = False @@ -439,47 +362,97 @@ async def transition_func(params: FunctionCallParams) -> None: result = acknowledged_result next_node = None is_transition_only_function = True + logger.debug( f"{'Transition-only function called for' if is_transition_only_function else 'Function handler completed for'} {name}" ) - # For edge functions, prevent LLM completion until transition (run_llm=False) - # For node functions, allow immediate completion (run_llm=True) + # Determine if this is an edge function has_explicit_transition = bool(transition_to) or bool(transition_callback) + is_edge_function = bool(next_node) or has_explicit_transition - async def on_context_updated() -> None: - if next_node: - await on_context_updated_edge( - next_node=next_node, - args=None, - result=None, - result_callback=params.result_callback, - ) - elif has_explicit_transition: - await on_context_updated_edge( - next_node=None, - args=params.arguments, - result=result, - result_callback=params.result_callback, - ) - else: - await on_context_updated_node() + if is_edge_function: + # Store transition info for coordinated execution + transition_info = { + "next_node": next_node, + "transition_to": transition_to, + "transition_callback": transition_callback, + "function_name": name, + "arguments": params.arguments, + "result": result, + } + self._pending_transition = transition_info + + properties = FunctionCallResultProperties( + run_llm=False, # Don't run LLM until transition completes + on_context_updated=self._check_and_execute_transition, + ) + else: + # Node function - run LLM immediately + properties = FunctionCallResultProperties( + run_llm=True, + on_context_updated=None, + ) - is_edge_function = bool(next_node) or has_explicit_transition - properties = FunctionCallResultProperties( - run_llm=not is_edge_function, - on_context_updated=on_context_updated, - ) await params.result_callback(result, properties=properties) except Exception as e: logger.error(f"Error in transition function {name}: {str(e)}") - self._pending_function_calls = 0 error_result = {"status": "error", "error": str(e)} await params.result_callback(error_result) return transition_func + async def _check_and_execute_transition(self) -> None: + """Check if all functions are complete and execute transition if so.""" + if not self._pending_transition: + return + + # Check if all function calls are complete using Pipecat's state + assistant_aggregator = self._context_aggregator.assistant() + if not assistant_aggregator._function_calls_in_progress: + # All functions complete, execute transition + transition_info = self._pending_transition + self._pending_transition = None + + await self._execute_transition(transition_info) + + async def _execute_transition(self, transition_info: Dict[str, Any]) -> None: + """Execute the stored transition.""" + next_node = transition_info.get("next_node") + transition_to = transition_info.get("transition_to") + transition_callback = transition_info.get("transition_callback") + function_name = transition_info.get("function_name") + arguments = transition_info.get("arguments") + result = transition_info.get("result") + + try: + if next_node: # Function-returned next node (consolidated function) + if isinstance(next_node, str): # Static flow + node_name = next_node + node = self.nodes[next_node] + else: # Dynamic flow + node_name = get_or_generate_node_name(next_node) + node = next_node + logger.debug(f"Transition to function-returned node: {node_name}") + await self._set_node(node_name, node) + elif transition_to: # Static flow (deprecated) + logger.debug(f"Static transition to: {transition_to}") + await self._set_node(transition_to, self.nodes[transition_to]) + elif transition_callback: # Dynamic flow (deprecated) + logger.debug(f"Dynamic transition for: {function_name}") + # Check callback signature + sig = inspect.signature(transition_callback) + if len(sig.parameters) == 2: + # Old style: (args, flow_manager) + await transition_callback(arguments, self) + else: + # New style: (args, result, flow_manager) + await transition_callback(arguments, result, self) + except Exception as e: + logger.error(f"Error executing transition: {str(e)}") + raise + def _lookup_function(self, func_name: str) -> Callable: """Look up a function by name in the main module. @@ -614,6 +587,8 @@ async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: raise FlowTransitionError(f"{self.__class__.__name__} must be initialized first") try: + self._pending_transition = None + self._validate_node_config(node_id, node_config) logger.debug(f"Setting node: {node_id}") From 3ac577b2996492356bc858e5ae4f1068e0dd2053 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 13 Jun 2025 12:46:57 -0400 Subject: [PATCH 2/5] Update unit tests --- tests/test_manager.py | 120 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 114 insertions(+), 6 deletions(-) diff --git a/tests/test_manager.py b/tests/test_manager.py index d605534..191caee 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -54,6 +54,10 @@ async def asyncSetUp(self): self.mock_llm = OpenAILLMService(api_key="") self.mock_llm.register_function = MagicMock() + # Create mock assistant aggregator with function call tracking + self.mock_assistant_aggregator = MagicMock() + self.mock_assistant_aggregator._function_calls_in_progress = {} + # Create mock context aggregator self.mock_context_aggregator = MagicMock() self.mock_context_aggregator.user = MagicMock() @@ -62,6 +66,10 @@ async def asyncSetUp(self): return_value=MagicMock() ) + self.mock_context_aggregator.assistant = MagicMock( + return_value=self.mock_assistant_aggregator + ) + self.mock_result_callback = AsyncMock() # Sample node configurations @@ -255,6 +263,8 @@ async def result_callback(result, properties=None): await func(params) + self.mock_assistant_aggregator._function_calls_in_progress = {} + # Execute the context_updated callback self.assertIsNotNone(context_updated_callback, "Context updated callback not set") await context_updated_callback() @@ -297,6 +307,8 @@ async def result_callback(result, properties=None): ) await func(params) + self.mock_assistant_aggregator._function_calls_in_progress = {} + # Execute the context_updated callback self.assertIsNotNone(context_updated_callback, "Context updated callback not set") await context_updated_callback() @@ -652,6 +664,8 @@ async def result_callback(result, properties=None): # Call function await transition_func(params) + self.mock_assistant_aggregator._function_calls_in_progress = {} + # Execute the context updated callback which should trigger the error self.assertIsNotNone(context_updated_callback, "Context updated callback not set") try: @@ -659,13 +673,10 @@ async def result_callback(result, properties=None): except ValueError: pass # Expected error - # Verify error handling - should have two results: - # 1. The initial acknowledged status - # 2. The error status after the callback fails - self.assertEqual(len(final_results), 2) + # Verify error handling - should have only one result (the initial acknowledged status) + # The error handling in our new implementation doesn't call result_callback again + self.assertEqual(len(final_results), 1) self.assertEqual(final_results[0]["status"], "acknowledged") - self.assertEqual(final_results[1]["status"], "error") - self.assertIn("Transition error", final_results[1]["error"]) async def test_register_function_error_handling(self): """Test error handling in function registration.""" @@ -1320,3 +1331,100 @@ async def test_node_with_empty_functions(self): if any(isinstance(frame, LLMSetToolsFrame) for frame in call[0][0]) ] self.assertTrue(len(tools_frames_call) > 0, "Should have called LLMSetToolsFrame") + + async def test_multiple_edge_functions_single_transition(self): + """Test that multiple edge functions coordinate properly and only transition once.""" + flow_manager = FlowManager( + task=self.mock_task, + llm=self.mock_llm, + context_aggregator=self.mock_context_aggregator, + ) + await flow_manager.initialize() + + transitions_executed = 0 + + async def transition_callback(args, flow_manager): + nonlocal transitions_executed + transitions_executed += 1 + + # Create real async handler functions instead of AsyncMock + async def edge_handler_1(args): + return {"status": "success", "function": "edge_func_1"} + + async def edge_handler_2(args): + return {"status": "success", "function": "edge_func_2"} + + # Create node with multiple edge functions pointing to same transition + node_config: NodeConfig = { + "task_messages": [{"role": "system", "content": "Test"}], + "functions": [ + { + "type": "function", + "function": { + "name": "edge_func_1", + "handler": edge_handler_1, + "description": "Edge function 1", + "parameters": {}, + "transition_callback": transition_callback, + }, + }, + { + "type": "function", + "function": { + "name": "edge_func_2", + "handler": edge_handler_2, + "description": "Edge function 2", + "parameters": {}, + "transition_callback": transition_callback, + }, + }, + ], + } + + await flow_manager._set_node("test", node_config) + + # Get both registered functions + func1 = None + func2 = None + for call_args in self.mock_llm.register_function.call_args_list: + name, func = call_args[0] + if name == "edge_func_1": + func1 = func + elif name == "edge_func_2": + func2 = func + + self.assertIsNotNone(func1, "edge_func_1 should be registered") + self.assertIsNotNone(func2, "edge_func_2 should be registered") + + # Simulate both functions being called + context_callbacks = [] + + async def result_callback(result, properties=None): + if properties and properties.on_context_updated: + context_callbacks.append(properties.on_context_updated) + + # Call both functions + await func1(FunctionCallParams("edge_func_1", "id1", {}, None, None, result_callback)) + await func2(FunctionCallParams("edge_func_2", "id2", {}, None, None, result_callback)) + + # Verify both functions created context callbacks + self.assertEqual( + len(context_callbacks), 2, "Both functions should create context callbacks" + ) + + # Initially both functions are "in progress" + self.mock_assistant_aggregator._function_calls_in_progress = {"id1": True, "id2": True} + + # First function completes - should not transition yet + self.mock_assistant_aggregator._function_calls_in_progress = {"id2": True} + await context_callbacks[0]() + self.assertEqual( + transitions_executed, 0, "Should not transition while functions still pending" + ) + + # Second function completes - should transition now + self.mock_assistant_aggregator._function_calls_in_progress = {} + await context_callbacks[1]() + self.assertEqual( + transitions_executed, 1, "Should transition exactly once when all functions complete" + ) From 7e89eedc694d87e780f15a67f28ef7bb789e823d Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 13 Jun 2025 22:09:58 -0400 Subject: [PATCH 3/5] Use the new public property has_function_calls_in_progress --- src/pipecat_flows/manager.py | 2 +- tests/test_manager.py | 33 +++++++++++++++++++++------------ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index de46b80..307117c 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -410,7 +410,7 @@ async def _check_and_execute_transition(self) -> None: # Check if all function calls are complete using Pipecat's state assistant_aggregator = self._context_aggregator.assistant() - if not assistant_aggregator._function_calls_in_progress: + if not assistant_aggregator.has_function_calls_in_progress: # All functions complete, execute transition transition_info = self._pending_transition self._pending_transition = None diff --git a/tests/test_manager.py b/tests/test_manager.py index 191caee..7b0c340 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -20,7 +20,7 @@ import unittest from typing import Dict -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch from pipecat.frames.frames import ( LLMMessagesAppendFrame, @@ -54,9 +54,11 @@ async def asyncSetUp(self): self.mock_llm = OpenAILLMService(api_key="") self.mock_llm.register_function = MagicMock() - # Create mock assistant aggregator with function call tracking + # Create mock assistant aggregator with public property only self.mock_assistant_aggregator = MagicMock() - self.mock_assistant_aggregator._function_calls_in_progress = {} + type(self.mock_assistant_aggregator).has_function_calls_in_progress = PropertyMock( + return_value=False # Default to no functions in progress + ) # Create mock context aggregator self.mock_context_aggregator = MagicMock() @@ -263,7 +265,9 @@ async def result_callback(result, properties=None): await func(params) - self.mock_assistant_aggregator._function_calls_in_progress = {} + # Set up the property mock to return False (no functions in progress) + property_mock = PropertyMock(return_value=False) + type(self.mock_assistant_aggregator).has_function_calls_in_progress = property_mock # Execute the context_updated callback self.assertIsNotNone(context_updated_callback, "Context updated callback not set") @@ -307,7 +311,9 @@ async def result_callback(result, properties=None): ) await func(params) - self.mock_assistant_aggregator._function_calls_in_progress = {} + # Set up the property mock to return False (no functions in progress) + property_mock = PropertyMock(return_value=False) + type(self.mock_assistant_aggregator).has_function_calls_in_progress = property_mock # Execute the context_updated callback self.assertIsNotNone(context_updated_callback, "Context updated callback not set") @@ -664,7 +670,9 @@ async def result_callback(result, properties=None): # Call function await transition_func(params) - self.mock_assistant_aggregator._function_calls_in_progress = {} + # Set up the property mock to return False (no functions in progress) + property_mock = PropertyMock(return_value=False) + type(self.mock_assistant_aggregator).has_function_calls_in_progress = property_mock # Execute the context updated callback which should trigger the error self.assertIsNotNone(context_updated_callback, "Context updated callback not set") @@ -1412,18 +1420,19 @@ async def result_callback(result, properties=None): len(context_callbacks), 2, "Both functions should create context callbacks" ) - # Initially both functions are "in progress" - self.mock_assistant_aggregator._function_calls_in_progress = {"id1": True, "id2": True} + # Create a mock property that we can control dynamically + property_mock = PropertyMock() + type(self.mock_assistant_aggregator).has_function_calls_in_progress = property_mock - # First function completes - should not transition yet - self.mock_assistant_aggregator._function_calls_in_progress = {"id2": True} + # First function completes - should not transition yet (functions still in progress) + property_mock.return_value = True await context_callbacks[0]() self.assertEqual( transitions_executed, 0, "Should not transition while functions still pending" ) - # Second function completes - should transition now - self.mock_assistant_aggregator._function_calls_in_progress = {} + # Second function completes - should transition now (no functions in progress) + property_mock.return_value = False await context_callbacks[1]() self.assertEqual( transitions_executed, 1, "Should transition exactly once when all functions complete" From dd1baf5b531bafb0c1c5b513dd7699f4a8483ce8 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 16 Jun 2025 15:01:21 -0400 Subject: [PATCH 4/5] Code review fixes, remove _set_node from test_manager --- src/pipecat_flows/manager.py | 5 +-- tests/test_manager.py | 66 ++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 307117c..e9203ce 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -368,8 +368,9 @@ async def transition_func(params: FunctionCallParams) -> None: ) # Determine if this is an edge function - has_explicit_transition = bool(transition_to) or bool(transition_callback) - is_edge_function = bool(next_node) or has_explicit_transition + is_edge_function = ( + bool(next_node) or bool(transition_to) or bool(transition_callback) + ) if is_edge_function: # Store transition info for coordinated execution diff --git a/tests/test_manager.py b/tests/test_manager.py index 7b0c340..94c1c57 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -95,8 +95,8 @@ async def asyncSetUp(self): self.static_flow_config = { "initial_node": "start", "nodes": { - "start": self.sample_node, - "next_node": self.sample_node, + "start": {"name": "start", **self.sample_node}, + "next_node": {"name": "next_node", **self.sample_node}, }, } @@ -168,7 +168,7 @@ async def test_static_flow_transitions(self): # In static flows, transitions happen through set_node with a # predefined node configuration from the flow_config - await flow_manager._set_node("next_node", flow_manager.nodes["next_node"]) + await flow_manager.set_node_from_config(flow_manager.nodes["next_node"]) # Verify node transition occurred self.assertEqual(flow_manager.current_node, "next_node") @@ -242,7 +242,7 @@ async def test_handler(args: FlowArgs) -> FlowResult: } # Test old style callback - await flow_manager._set_node("old_style", old_style_node) + await flow_manager.set_node_from_config(old_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Store the context_updated callback @@ -294,7 +294,7 @@ async def result_callback(result, properties=None): } # Test new style callback - await flow_manager._set_node("new_style", new_style_node) + await flow_manager.set_node_from_config(new_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Reset context_updated callback @@ -335,12 +335,12 @@ async def test_node_validation(self): # Test missing task_messages invalid_config = {"functions": []} with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", invalid_config) + await flow_manager.set_node_from_config(invalid_config) self.assertIn("missing required 'task_messages' field", str(context.exception)) # Test valid config - valid_config = {"task_messages": []} - await flow_manager._set_node("test", valid_config) + valid_config = {"name": "test", "task_messages": []} + await flow_manager.set_node_from_config(valid_config) self.assertEqual(flow_manager.current_node, "test") self.assertEqual(flow_manager.current_functions, set()) @@ -358,7 +358,7 @@ async def test_function_registration(self): self.mock_llm.register_function.reset_mock() # Set node with function - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node_from_config(self.sample_node) # Verify function was registered self.mock_llm.register_function.assert_called_once() @@ -388,7 +388,7 @@ async def test_action_execution(self): self.mock_task.queue_frame.reset_mock() # Set node with actions - await flow_manager._set_node("test", node_with_actions) + await flow_manager.set_node_from_config(node_with_actions) assert_tts_speak_frames_queued(self.mock_task, ["Pre action", "Post action"]) @@ -408,7 +408,7 @@ async def test_error_handling(self): # Test setting node before initialization with self.assertRaises(FlowTransitionError): - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node_from_config(self.sample_node) # Initialize normally await flow_manager.initialize() @@ -417,7 +417,7 @@ async def test_error_handling(self): # Test node setting error self.mock_task.queue_frames.side_effect = Exception("Queue error") with self.assertRaises(FlowError): - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node_from_config(self.sample_node) # Verify flow manager remains initialized despite error self.assertTrue(flow_manager.initialized) @@ -439,7 +439,7 @@ async def test_state_management(self): self.mock_task.queue_frames.reset_mock() # Verify state persists across node transitions - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node_from_config(self.sample_node) self.assertEqual(flow_manager.state["test_key"], test_value) async def test_multiple_function_registration(self): @@ -467,7 +467,7 @@ async def test_multiple_function_registration(self): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) # Verify all functions were registered self.assertEqual(self.mock_llm.register_function.call_count, 3) @@ -577,11 +577,12 @@ async def test_node_validation_edge_cases(self): "functions": [{"type": "function"}], # Missing name } with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", invalid_config) + await flow_manager.set_node_from_config(invalid_config) self.assertIn("invalid format", str(context.exception)) # Test node function without handler or transition_to invalid_config = { + "name": "test", "task_messages": [{"role": "system", "content": "Test"}], "functions": [ { @@ -603,7 +604,7 @@ def capture_warning(msg, *args, **kwargs): warning_message = msg with patch("loguru.logger.warning", side_effect=capture_warning): - await flow_manager._set_node("test", invalid_config) + await flow_manager.set_node_from_config(invalid_config) self.assertIsNotNone(warning_message) self.assertIn( "Function 'test_func' in node 'test' has neither handler, transition_to, nor transition_callback", @@ -645,7 +646,7 @@ async def failing_handler(args, flow_manager): } # Set up node and get registered function - await flow_manager._set_node("test", test_node) + await flow_manager.set_node_from_config(test_node) transition_func = flow_manager.llm.register_function.call_args[0][1] # Track the result and context_updated callback @@ -721,7 +722,7 @@ async def test_action_execution_error_handling(self): # Should raise FlowError due to invalid actions with self.assertRaises(FlowError): - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) # Verify error handling for pre and post actions separately with self.assertRaises(FlowError): @@ -785,7 +786,7 @@ async def test_handler(args): } # Set node and verify function registration - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) # Verify both functions were registered self.assertIn("test1", flow_manager.current_functions) @@ -825,7 +826,7 @@ async def test_handler_main(args): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) self.assertIn("test_function", flow_manager.current_functions) finally: @@ -857,7 +858,7 @@ async def test_function_token_handling_not_found(self): } with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) self.assertIn("Function 'nonexistent_handler' not found", str(context.exception)) @@ -898,7 +899,7 @@ async def test_handler(args): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) # Get the registered function and test it name, func = self.mock_llm.register_function.call_args[0] @@ -947,7 +948,7 @@ async def test_role_message_inheritance(self): } # Set first node and verify UpdateFrame - await flow_manager._set_node("first", first_node) + await flow_manager.set_node_from_config(first_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] @@ -959,7 +960,7 @@ async def test_role_message_inheritance(self): # Reset mock and set second node self.mock_task.queue_frames.reset_mock() - await flow_manager._set_node("second", second_node) + await flow_manager.set_node_from_config(second_node) # Verify AppendFrame for second node first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call @@ -985,7 +986,7 @@ async def test_frame_type_selection(self): } # First node should use UpdateFrame - await flow_manager._set_node("first", test_node) + await flow_manager.set_node_from_config(test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] self.assertTrue( @@ -1001,7 +1002,7 @@ async def test_frame_type_selection(self): self.mock_task.queue_frames.reset_mock() # Second node should use AppendFrame - await flow_manager._set_node("second", test_node) + await flow_manager.set_node_from_config(test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call second_frames = first_call[0][0] self.assertTrue( @@ -1163,8 +1164,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager._set_node( - "initial", + await flow_manager.set_node_from_config( { "task_messages": [{"role": "system", "content": "Test"}], "functions": [], @@ -1189,7 +1189,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager._set_node("next", next_node) + await flow_manager.set_node_from_config(next_node) # Should see context update and completion trigger again self.assertTrue(self.mock_task.queue_frames.called) @@ -1230,7 +1230,7 @@ async def test_transition_configuration_exclusivity(self): # Should raise error when trying to use both with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", test_node) + await flow_manager.set_node_from_config(test_node) self.assertIn( "Cannot specify both transition_to and transition_callback", str(context.exception) ) @@ -1298,7 +1298,7 @@ async def test_node_without_functions(self): } # Set node and verify it works without error - await flow_manager._set_node("no_functions", node_config) + await flow_manager.set_node_from_config(node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) @@ -1327,7 +1327,7 @@ async def test_node_with_empty_functions(self): } # Set node and verify it works without error - await flow_manager._set_node("empty_functions", node_config) + await flow_manager.set_node_from_config(node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) @@ -1389,7 +1389,7 @@ async def edge_handler_2(args): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node_from_config(node_config) # Get both registered functions func1 = None From d22d2ae61a172e2acc48f959b10beee74a66ef46 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 16 Jun 2025 16:37:45 -0400 Subject: [PATCH 5/5] Add clarifying comment about self_pending_transition in _set_node --- src/pipecat_flows/manager.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index e9203ce..50b04ed 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -588,6 +588,10 @@ async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: raise FlowTransitionError(f"{self.__class__.__name__} must be initialized first") try: + # Clear any pending transition state when starting a new node + # This ensures clean state regardless of how we arrived here: + # - Normal transition flow (already cleared in _check_and_execute_transition) + # - Direct calls to set_node/set_node_from_config self._pending_transition = None self._validate_node_config(node_id, node_config)