diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..65177ab2 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,56 @@ +name: test + +on: + workflow_dispatch: + push: + branches: + - main + pull_request: + branches: + - '**' + +concurrency: + group: build-test-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + test: + name: 'Unit Tests' + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v4 + + - name: Set up Python + id: setup_python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Cache virtual environment + uses: actions/cache@v3 + with: + key: venv-${{ runner.os }}-${{ steps.setup_python.outputs.python-version}}-${{ hashFiles('dev-requirements.txt') }}-${{ hashFiles('test-requirements.txt') }} + path: .venv + + - name: Setup virtual environment + run: | + python -m venv .venv + + - name: Install dependencies + run: | + source .venv/bin/activate + python -m pip install --upgrade pip + pip install -r dev-requirements.txt -r test-requirements.txt + pip install -e . + + - name: Test with pytest + run: | + source .venv/bin/activate + pytest tests/ --cov=pipecat_flows --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + fail_ci_if_error: true diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 00000000..38842c68 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,7 @@ +build~=1.2.1 +pip-tools~=7.4.1 +pytest~=8.3.2 +pytest-asyncio~=0.23.5 +pytest-cov~=4.1.0 +ruff~=0.6.7 +setuptools~=72.2.0 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d662141e..6554d16a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64"] # Removed setuptools_scm +requires = ["setuptools>=64"] build-backend = "setuptools.build_meta" [project] @@ -9,7 +9,7 @@ description = "Conversation Flow management for Pipecat AI applications" license = { text = "BSD 2-Clause License" } readme = "README.md" requires-python = ">=3.10" -keywords = ["pipecat", " conversation", "flows", "state machine", "ai", "llm"] +keywords = ["pipecat", "conversation", "flows", "state machine", "ai", "llm"] classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -27,5 +27,10 @@ dependencies = [ Source = "https://github.com/pipecat-ai/pipecat-flows" Website = "https://www.pipecat.ai" +[tool.pytest.ini_options] +pythonpath = ["src"] +testpaths = ["tests"] +asyncio_mode = "auto" + [tool.ruff] line-length = 100 \ No newline at end of file diff --git a/src/pipecat_flows/state.py b/src/pipecat_flows/state.py index ac559982..d889b20e 100644 --- a/src/pipecat_flows/state.py +++ b/src/pipecat_flows/state.py @@ -36,6 +36,11 @@ def __init__(self, flow_config: FlowConfig, llm): Raises: ValueError: If required configuration keys are missing """ + if "initial_node" not in flow_config: + raise ValueError("Flow config must specify 'initial_node'") + if "nodes" not in flow_config: + raise ValueError("Flow config must specify 'nodes'") + self.nodes: Dict[str, NodeConfig] = {} self.current_node: str = flow_config["initial_node"] self.adapter = create_adapter(llm) @@ -48,13 +53,20 @@ def _load_config(self, config: FlowConfig): config: Dictionary containing the flow configuration Raises: - ValueError: If required configuration keys are missing + ValueError: If required configuration keys are missing 
or invalid """ - if "initial_node" not in config: - raise ValueError("Flow config must specify 'initial_node'") - if "nodes" not in config: - raise ValueError("Flow config must specify 'nodes'") - + initial_node = config["initial_node"] + if initial_node not in config["nodes"]: + raise ValueError(f"Initial node '{initial_node}' not found in nodes") + + # Validate node structure + for node_id, node in config["nodes"].items(): + if "messages" not in node: + raise ValueError(f"Node '{node_id}' missing required 'messages' field") + if "functions" not in node: + raise ValueError(f"Node '{node_id}' missing required 'functions' field") + + # Load the nodes self.nodes = config["nodes"] def get_current_messages(self) -> List[dict]: diff --git a/test-requirements.txt b/test-requirements.txt new file mode 100644 index 00000000..dbfa4e3e --- /dev/null +++ b/test-requirements.txt @@ -0,0 +1,5 @@ +pipecat-ai>=0.0.49 +loguru~=0.7.2 +anthropic~=0.30.0 +google-generativeai~=0.7.2 +openai~=1.37.2 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_actions.py b/tests/test_actions.py new file mode 100644 index 00000000..e27d2cef --- /dev/null +++ b/tests/test_actions.py @@ -0,0 +1,179 @@ +import unittest +from unittest.mock import AsyncMock, patch + +from pipecat.frames.frames import EndFrame, TTSSpeakFrame + +from pipecat_flows.actions import ActionManager +from pipecat_flows.exceptions import ActionError + + +class TestActionManager(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """ + Set up test fixtures before each test. + + Creates: + - Mock PipelineTask for frame queueing + - Mock TTS service for speech synthesis + - ActionManager instance with mocked dependencies + """ + self.mock_task = AsyncMock() + self.mock_task.queue_frame = AsyncMock() + + self.mock_tts = AsyncMock() + self.mock_tts.say = AsyncMock() + + self.action_manager = ActionManager(self.mock_task, self.mock_tts) + + async def test_initialization(self): + """Test ActionManager initialization and default handlers""" + # Verify built-in action handlers are registered + self.assertIn("tts_say", self.action_manager.action_handlers) + self.assertIn("end_conversation", self.action_manager.action_handlers) + + # Test initialization without TTS service + action_manager_no_tts = ActionManager(self.mock_task, None) + self.assertIsNone(action_manager_no_tts.tts) + + async def test_tts_action(self): + """Test basic TTS action execution""" + action = {"type": "tts_say", "text": "Hello"} + await self.action_manager.execute_actions([action]) + + # Verify TTS service was called with correct text + self.mock_tts.say.assert_called_once_with("Hello") + + @patch("loguru.logger.error") + async def test_tts_action_no_text(self, mock_logger): + """Test TTS action with missing text field""" + action = {"type": "tts_say"} # Missing text field + + # The implementation logs error but doesn't raise + await self.action_manager.execute_actions([action]) + + # Verify error was logged + mock_logger.assert_called_with("TTS action missing 'text' field") + + # Verify TTS service was not called + self.mock_tts.say.assert_not_called() + + @patch("loguru.logger.warning") + async def test_tts_action_no_service(self, mock_logger): + """Test TTS action when no TTS service is provided""" + action_manager = ActionManager(self.mock_task, None) + action = {"type": "tts_say", "text": "Hello"} + + # Should log warning but not raise error + await 
action_manager.execute_actions([action]) + + # Verify warning was logged + mock_logger.assert_called_with("TTS action called but no TTS service provided") + + # Verify no frames were queued + self.mock_task.queue_frame.assert_not_called() + + async def test_end_conversation_action(self): + """Test basic end conversation action""" + action = {"type": "end_conversation"} + await self.action_manager.execute_actions([action]) + + # Verify EndFrame was queued + self.mock_task.queue_frame.assert_called_once() + frame = self.mock_task.queue_frame.call_args[0][0] + self.assertIsInstance(frame, EndFrame) + + async def test_end_conversation_with_goodbye(self): + """Test end conversation action with goodbye message""" + action = {"type": "end_conversation", "text": "Goodbye!"} + await self.action_manager.execute_actions([action]) + + # Verify both frames were queued in correct order + self.assertEqual(self.mock_task.queue_frame.call_count, 2) + + # Verify TTSSpeakFrame + first_frame = self.mock_task.queue_frame.call_args_list[0][0][0] + self.assertIsInstance(first_frame, TTSSpeakFrame) + self.assertEqual(first_frame.text, "Goodbye!") + + # Verify EndFrame + second_frame = self.mock_task.queue_frame.call_args_list[1][0][0] + self.assertIsInstance(second_frame, EndFrame) + + async def test_custom_action(self): + """Test registering and executing custom actions""" + mock_handler = AsyncMock() + self.action_manager._register_action("custom", mock_handler) + + # Verify handler was registered + self.assertIn("custom", self.action_manager.action_handlers) + + # Execute custom action + action = {"type": "custom", "data": "test"} + await self.action_manager.execute_actions([action]) + + # Verify handler was called with correct data + mock_handler.assert_called_once_with(action) + + async def test_invalid_action(self): + """Test handling invalid actions""" + # Test missing type + with self.assertRaises(ActionError) as context: + await self.action_manager.execute_actions([{}]) + self.assertIn("missing required 'type' field", str(context.exception)) + + # Test unknown action type + with self.assertRaises(ActionError) as context: + await self.action_manager.execute_actions([{"type": "invalid"}]) + self.assertIn("No handler registered", str(context.exception)) + + async def test_multiple_actions(self): + """Test executing multiple actions in sequence""" + actions = [ + {"type": "tts_say", "text": "First"}, + {"type": "tts_say", "text": "Second"}, + ] + await self.action_manager.execute_actions(actions) + + # Verify TTS was called twice in correct order + self.assertEqual(self.mock_tts.say.call_count, 2) + expected_calls = [unittest.mock.call("First"), unittest.mock.call("Second")] + self.assertEqual(self.mock_tts.say.call_args_list, expected_calls) + + def test_register_invalid_handler(self): + """Test registering invalid action handlers""" + # Test non-callable handler + with self.assertRaises(ValueError) as context: + self.action_manager._register_action("invalid", "not_callable") + self.assertIn("must be callable", str(context.exception)) + + # Test None handler + with self.assertRaises(ValueError) as context: + self.action_manager._register_action("invalid", None) + self.assertIn("must be callable", str(context.exception)) + + async def test_none_or_empty_actions(self): + """Test handling None or empty action lists""" + # Test None actions + await self.action_manager.execute_actions(None) + self.mock_task.queue_frame.assert_not_called() + self.mock_tts.say.assert_not_called() + + # Test empty list + await 
self.action_manager.execute_actions([]) + self.mock_task.queue_frame.assert_not_called() + self.mock_tts.say.assert_not_called() + + @patch("loguru.logger.error") + async def test_action_error_handling(self, mock_logger): + """Test error handling during action execution""" + # Configure TTS mock to raise an error + self.mock_tts.say.side_effect = Exception("TTS error") + + action = {"type": "tts_say", "text": "Hello"} + await self.action_manager.execute_actions([action]) + + # Verify error was logged + mock_logger.assert_called_with("TTS error: TTS error") + + # Verify action was still marked as executed (doesn't raise) + self.mock_tts.say.assert_called_once() diff --git a/tests/test_adapters.py b/tests/test_adapters.py new file mode 100644 index 00000000..4932de0c --- /dev/null +++ b/tests/test_adapters.py @@ -0,0 +1,237 @@ +import unittest +from unittest.mock import MagicMock + +from pipecat.services.anthropic import AnthropicLLMService +from pipecat.services.google import GoogleLLMService +from pipecat.services.openai import OpenAILLMService + +from pipecat_flows.adapters import ( + AnthropicAdapter, + GeminiAdapter, + LLMAdapter, + OpenAIAdapter, + create_adapter, +) + + +class TestLLMAdapter(unittest.TestCase): + """Test the abstract base LLMAdapter class""" + + def test_abstract_methods(self): + """Verify that LLMAdapter cannot be instantiated without implementing all methods""" + + class IncompleteAdapter(LLMAdapter): + # Missing implementation of abstract methods + pass + + with self.assertRaises(TypeError): + IncompleteAdapter() + + class PartialAdapter(LLMAdapter): + def get_function_name(self, function_def): + return "test" + + # Still missing other required methods + + with self.assertRaises(TypeError): + PartialAdapter() + + +class TestLLMAdapters(unittest.TestCase): + def setUp(self): + """Set up test cases with sample function definitions for each provider""" + # OpenAI format + self.openai_function = { + "type": "function", + "function": { + "name": "test_function", + "description": "Test function", + "parameters": {"type": "object", "properties": {"param1": {"type": "string"}}}, + }, + } + + self.openai_function_call = {"name": "test_function", "arguments": {"param1": "value1"}} + + # Anthropic format + self.anthropic_function = { + "name": "test_function", + "description": "Test function", + "input_schema": {"type": "object", "properties": {"param1": {"type": "string"}}}, + } + + self.anthropic_function_call = {"name": "test_function", "arguments": {"param1": "value1"}} + + # Gemini format + self.gemini_function = { + "function_declarations": [ + { + "name": "test_function", + "description": "Test function", + "parameters": {"type": "object", "properties": {"param1": {"type": "string"}}}, + } + ] + } + + self.gemini_function_call = {"name": "test_function", "args": {"param1": "value1"}} + + # Message formats + self.openai_message = {"role": "system", "content": "Test message"} + + self.null_message = {"role": "system", "content": None} + + self.anthropic_message = { + "role": "user", + "content": [{"type": "text", "text": "Test message"}], + } + + self.gemini_message = {"role": "user", "content": "Test message"} + + def test_openai_adapter(self): + """Test OpenAI format handling""" + adapter = OpenAIAdapter() + + # Test function name extraction + self.assertEqual(adapter.get_function_name(self.openai_function), "test_function") + + # Test function arguments extraction + args = adapter.get_function_args(self.openai_function_call) + self.assertEqual(args, {"param1": 
"value1"}) + + # Test message content extraction + self.assertEqual(adapter.get_message_content(self.openai_message), "Test message") + + # Test null message content + # The implementation returns None for null content + self.assertIsNone(adapter.get_message_content(self.null_message)) + + # Test function formatting + formatted = adapter.format_functions([self.openai_function]) + self.assertEqual(formatted, [self.openai_function]) + + def test_anthropic_adapter(self): + """Test Anthropic format handling""" + adapter = AnthropicAdapter() + + # Test function name extraction + self.assertEqual(adapter.get_function_name(self.anthropic_function), "test_function") + + # Test function arguments extraction + self.assertEqual( + adapter.get_function_args(self.anthropic_function_call), {"param1": "value1"} + ) + + # Test message content extraction + self.assertEqual(adapter.get_message_content(self.anthropic_message), "Test message") + + # Test function formatting + formatted = adapter.format_functions([self.openai_function]) + self.assertTrue("input_schema" in formatted[0]) + self.assertEqual(formatted[0]["name"], "test_function") + + def test_gemini_adapter(self): + """Test Gemini format handling""" + adapter = GeminiAdapter() + + # Test function name extraction from function declarations + self.assertEqual( + adapter.get_function_name(self.gemini_function["function_declarations"][0]), + "test_function", + ) + + # Test function arguments extraction + self.assertEqual(adapter.get_function_args(self.gemini_function_call), {"param1": "value1"}) + + # Test message content extraction + self.assertEqual(adapter.get_message_content(self.gemini_message), "Test message") + + # Test function formatting + formatted = adapter.format_functions([self.openai_function]) + self.assertTrue("function_declarations" in formatted[0]) + + def test_adapter_factory(self): + """Test adapter creation based on LLM service type""" + # Test with valid LLM services + openai_llm = MagicMock(spec=OpenAILLMService) + self.assertIsInstance(create_adapter(openai_llm), OpenAIAdapter) + + anthropic_llm = MagicMock(spec=AnthropicLLMService) + self.assertIsInstance(create_adapter(anthropic_llm), AnthropicAdapter) + + gemini_llm = MagicMock(spec=GoogleLLMService) + self.assertIsInstance(create_adapter(gemini_llm), GeminiAdapter) + + def test_adapter_factory_error_cases(self): + """Test error cases in adapter creation""" + # Test with None + with self.assertRaises(ValueError) as context: + create_adapter(None) + self.assertIn("Unsupported LLM type", str(context.exception)) + + # Test with invalid service type + invalid_llm = MagicMock() + with self.assertRaises(ValueError) as context: + create_adapter(invalid_llm) + self.assertIn("Unsupported LLM type", str(context.exception)) + + def test_null_and_empty_values(self): + """Test handling of null and empty values""" + adapters = [OpenAIAdapter(), AnthropicAdapter(), GeminiAdapter()] + + for adapter in adapters: + # Test empty function call + empty_call = {"name": "test"} + self.assertEqual(adapter.get_function_args(empty_call), {}) + + # Test empty message + empty_message = {"role": "user", "content": ""} + self.assertEqual(adapter.get_message_content(empty_message), "") + + def test_special_characters_handling(self): + """Test handling of special characters in messages and function calls""" + special_chars = "!@#$%^&*()_+-=[]{}|;:'\",.<>?/~`" + + # Test in message content + message_with_special = {"role": "user", "content": f"Test with {special_chars}"} + + adapters = [OpenAIAdapter(), 
AnthropicAdapter(), GeminiAdapter()] + for adapter in adapters: + content = adapter.get_message_content(message_with_special) + self.assertEqual(content, f"Test with {special_chars}") + + # Test in function arguments + # Each adapter might handle arguments differently, so test them separately + + # OpenAI + openai_adapter = OpenAIAdapter() + openai_call = {"name": "test", "arguments": {"param1": special_chars}} + args = openai_adapter.get_function_args(openai_call) + self.assertEqual(args["param1"], special_chars) + + # Anthropic + anthropic_adapter = AnthropicAdapter() + anthropic_call = {"name": "test", "arguments": {"param1": special_chars}} + args = anthropic_adapter.get_function_args(anthropic_call) + self.assertEqual(args["param1"], special_chars) + + # Gemini + gemini_adapter = GeminiAdapter() + gemini_call = { + "name": "test", + "args": {"param1": special_chars}, # Note: Gemini uses 'args' instead of 'arguments' + } + args = gemini_adapter.get_function_args(gemini_call) + self.assertEqual(args["param1"], special_chars) + + def test_function_schema_validation(self): + """Test validation of function schemas during conversion""" + adapters = [OpenAIAdapter(), AnthropicAdapter(), GeminiAdapter()] + + # Test with minimal valid schema + minimal_function = { + "type": "function", + "function": {"name": "test", "parameters": {"type": "object", "properties": {}}}, + } + + for adapter in adapters: + formatted = adapter.format_functions([minimal_function]) + self.assertTrue(len(formatted) > 0) diff --git a/tests/test_manager.py b/tests/test_manager.py new file mode 100644 index 00000000..6c705dfc --- /dev/null +++ b/tests/test_manager.py @@ -0,0 +1,453 @@ +import asyncio +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from pipecat.frames.frames import ( + EndFrame, + LLMMessagesAppendFrame, + LLMMessagesUpdateFrame, + LLMSetToolsFrame, +) +from pipecat.services.anthropic import AnthropicLLMService +from pipecat.services.google import GoogleLLMService +from pipecat.services.openai import OpenAILLMService + +from pipecat_flows import FlowManager +from pipecat_flows.exceptions import ( + FlowTransitionError, + InvalidFunctionError, +) + + +class TestFlowManager(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Set up test fixtures before each test""" + # Reset mock call counts for each test + self.mock_task = AsyncMock() + self.mock_task.queue_frame = AsyncMock() + self.mock_tts = AsyncMock() + self.mock_tts.say = AsyncMock() + + # Create fresh LLM mocks for each test + self.mock_openai = MagicMock(spec=OpenAILLMService) + self.mock_anthropic = MagicMock(spec=AnthropicLLMService) + self.mock_gemini = MagicMock(spec=GoogleLLMService) + + # Provider-specific flow configurations + self.openai_config = { + "initial_node": "start", + "nodes": { + "start": { + "messages": [{"role": "system", "content": "Start node"}], + "functions": [ + { + "type": "function", + "function": { + "name": "process", + "description": "Process data", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string"}}, + }, + }, + }, + { + "type": "function", + "function": { + "name": "middle", + "description": "Go to middle", + "parameters": {}, + }, + }, + ], + "pre_actions": [{"type": "tts_say", "text": "Starting..."}], + }, + "middle": { + "messages": [{"role": "system", "content": "Middle node"}], + "functions": [ + { + "type": "function", + "function": { + "name": "end", + "description": "End conversation", + "parameters": {}, + }, + } + ], + "post_actions": 
[{"type": "tts_say", "text": "Processing complete"}], + }, + "end": { + "messages": [{"role": "system", "content": "End node"}], + "functions": [], + "pre_actions": [ + {"type": "tts_say", "text": "Goodbye!"}, + {"type": "end_conversation"}, + ], + }, + }, + } + + self.anthropic_config = { + "initial_node": "start", + "nodes": { + "start": { + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Start node"}]} + ], + "functions": [ + { + "name": "process", + "description": "Process data", + "input_schema": { + "type": "object", + "properties": {"data": {"type": "string"}}, + }, + }, + {"name": "middle", "description": "Go to middle", "input_schema": {}}, + ], + "pre_actions": [{"type": "tts_say", "text": "Starting..."}], + }, + "middle": { + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Middle node"}]} + ], + "functions": [ + {"name": "end", "description": "End conversation", "input_schema": {}} + ], + }, + "end": { + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "End node"}]} + ], + "functions": [], + }, + }, + } + + self.gemini_config = { + "initial_node": "start", + "nodes": { + "start": { + "messages": [{"role": "system", "content": "Start node"}], + "functions": [ + { + "function_declarations": [ + { + "name": "process", + "description": "Process data", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string"}}, + }, + }, + {"name": "middle", "description": "Go to middle", "parameters": {}}, + ] + } + ], + "pre_actions": [{"type": "tts_say", "text": "Starting..."}], + }, + "middle": { + "messages": [{"role": "system", "content": "Middle node"}], + "functions": [ + { + "function_declarations": [ + {"name": "end", "description": "End conversation", "parameters": {}} + ] + } + ], + }, + "end": {"messages": [{"role": "system", "content": "End node"}], "functions": []}, + }, + } + + async def test_initialization_all_providers(self): + """Test initialization with all LLM providers""" + initial_messages = [{"role": "system", "content": "Initial context"}] + + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Reset mock call counts + self.mock_task.reset_mock() + + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize(initial_messages) + + # Verify initialization state + self.assertTrue(flow_manager.initialized) + + # Verify frames were queued + calls = self.mock_task.queue_frame.call_args_list + self.assertEqual(len(calls), 2) # Should have exactly 2 calls + + # Verify first call is LLMMessagesUpdateFrame + self.assertIsInstance(calls[0].args[0], LLMMessagesUpdateFrame) + self.assertEqual( + calls[0].args[0].messages, initial_messages + config["nodes"]["start"]["messages"] + ) + + # Verify second call is LLMSetToolsFrame + self.assertIsInstance(calls[1].args[0], LLMSetToolsFrame) + self.assertEqual(calls[1].args[0].tools, config["nodes"]["start"]["functions"]) + + async def test_function_registration_all_providers(self): + """Test function registration for all providers""" + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Reset mock + llm.register_function.reset_mock() + + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize([]) + + # Verify edge functions were registered with 
wrapper function + llm.register_function.assert_any_call( + "middle", + unittest.mock.ANY, # Accept any function since it's wrapped + ) + + async def test_transitions_all_providers(self): + """Test transitions with all providers""" + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Reset mock + self.mock_task.reset_mock() + + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize([]) + + # Test valid transition + await flow_manager.handle_transition("middle") + + # Get all calls after initialization + calls = self.mock_task.queue_frame.call_args_list + + # Verify LLMMessagesAppendFrame and LLMSetToolsFrame were called + append_frames = [c for c in calls if isinstance(c.args[0], LLMMessagesAppendFrame)] + tools_frames = [c for c in calls if isinstance(c.args[0], LLMSetToolsFrame)] + + self.assertTrue(len(append_frames) > 0) + self.assertTrue(len(tools_frames) > 0) + + # Verify content of last frames + last_append = append_frames[-1] + self.assertEqual(last_append.args[0].messages, config["nodes"]["middle"]["messages"]) + + async def test_complex_flow_all_providers(self): + """Test complete flow sequences with all providers""" + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Reset mocks + self.mock_task.reset_mock() + self.mock_tts.reset_mock() + + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize([]) + + # Execute complete flow + await flow_manager.handle_transition("middle") + await flow_manager.handle_transition("end") + + # Add small delay to allow actions to complete + await asyncio.sleep(0.1) + + # Verify key frames were queued + calls = self.mock_task.queue_frame.call_args_list + + # Print actual calls for debugging + print("\nActual calls:") + for call in calls: + print(f"- {type(call.args[0]).__name__}") + + # Verify message and tool frames + self.assertTrue( + any(isinstance(call.args[0], LLMMessagesAppendFrame) for call in calls), + "No LLMMessagesAppendFrame found", + ) + self.assertTrue( + any(isinstance(call.args[0], LLMSetToolsFrame) for call in calls), + "No LLMSetToolsFrame found", + ) + + # Verify end conversation actions were executed + if "pre_actions" in config["nodes"]["end"]: + end_actions = config["nodes"]["end"]["pre_actions"] + if any(action["type"] == "end_conversation" for action in end_actions): + # Either verify EndFrame was queued + end_frame_queued = any(isinstance(call.args[0], EndFrame) for call in calls) + # Or verify end_conversation action was registered + action_registered = any( + "end_conversation" in str(call) + for call in self.mock_task.queue_frame.mock_calls + ) + self.assertTrue( + end_frame_queued or action_registered, "No end conversation action found" + ) + + async def test_error_handling_all_providers(self): + """Test error handling for all providers""" + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Test uninitialized transition + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + with self.assertRaises(FlowTransitionError): + await flow_manager.handle_transition("middle") + + # Initialize and try invalid transition + await flow_manager.initialize([]) + with 
self.assertRaises(InvalidFunctionError): + await flow_manager.handle_transition("nonexistent") + + # Test double initialization + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize([]) + # Just verify it doesn't raise an error + await flow_manager.initialize([]) # Should log warning but not raise + + async def test_action_execution_all_providers(self): + """Test action execution for all providers""" + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Reset mocks + self.mock_tts.reset_mock() + self.mock_task.reset_mock() + + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize([]) + + # Print debug information + print( + f"\nTesting config with pre_actions: {config['nodes']['start'].get('pre_actions')}" + ) + print(f"TTS mock calls: {self.mock_tts.say.mock_calls}") + print(f"TTS mock called: {self.mock_tts.say.called}") + + # Test start node pre-actions + if "pre_actions" in config["nodes"]["start"]: + pre_actions = config["nodes"]["start"]["pre_actions"] + for action in pre_actions: + if action["type"] == "tts_say": + # Verify the action handler was registered + self.assertIn( + "tts_say", + flow_manager.action_manager.action_handlers, + "TTS action handler not registered", + ) + + # Execute the action explicitly + await flow_manager.action_manager.execute_actions([action]) + + # Verify TTS was called with correct text + self.assertTrue( + self.mock_tts.say.called, f"TTS say not called with action: {action}" + ) + self.mock_tts.say.assert_called_with(action["text"]) + + # Test middle node post-actions + if "post_actions" in config["nodes"]["middle"]: + # Reset TTS mock for middle node actions + self.mock_tts.reset_mock() + + # Transition to middle + await flow_manager.handle_transition("middle") + + post_actions = config["nodes"]["middle"]["post_actions"] + for action in post_actions: + if action["type"] == "tts_say": + self.assertTrue( + self.mock_tts.say.called, + f"TTS say not called with post-action: {action}", + ) + self.mock_tts.say.assert_called_with(action["text"]) + + # Test end node pre-actions + if "pre_actions" in config["nodes"]["end"]: + # Reset TTS mock for end node actions + self.mock_tts.reset_mock() + + # Transition to end from current node + await flow_manager.handle_transition("end") + + pre_actions = config["nodes"]["end"]["pre_actions"] + for action in pre_actions: + if action["type"] == "tts_say": + self.assertTrue( + self.mock_tts.say.called, + f"TTS say not called with end pre-action: {action}", + ) + self.mock_tts.say.assert_called_with(action["text"]) + + # Add debug output for state transitions + print(f"\nFinal node: {flow_manager.flow.current_node}") + print(f"Available functions: {flow_manager.flow.get_available_function_names()}") + + async def test_action_manager_setup(self): + """Test that action manager is properly initialized""" + flow_manager = FlowManager( + self.openai_config, self.mock_task, self.mock_openai, self.mock_tts + ) + + # Verify action manager exists + self.assertIsNotNone(flow_manager.action_manager) # Changed from _action_manager + + # Verify built-in actions are registered + self.assertIn("tts_say", flow_manager.action_manager.action_handlers) + self.assertIn("end_conversation", flow_manager.action_manager.action_handlers) + + # Verify TTS service is properly set + self.assertEqual(flow_manager.action_manager.tts, 
self.mock_tts) + + @patch("loguru.logger.debug") + async def test_logging_all_providers(self, mock_logger): + """Test logging for all providers""" + for config, llm in [ + (self.openai_config, self.mock_openai), + (self.anthropic_config, self.mock_anthropic), + (self.gemini_config, self.mock_gemini), + ]: + # Reset mock + mock_logger.reset_mock() + + flow_manager = FlowManager(config, self.mock_task, llm, self.mock_tts) + await flow_manager.initialize([]) + await flow_manager.handle_transition("middle") + + # Verify transition logging + mock_logger.assert_any_call("Attempting transition from start to middle") + + async def test_null_and_empty_cases(self): + """Test handling of null and empty values""" + # Test with empty messages + flow_manager = FlowManager( + self.openai_config, self.mock_task, self.mock_openai, self.mock_tts + ) + await flow_manager.initialize([]) + + # Test with empty functions + config = self.openai_config.copy() + config["nodes"]["start"]["functions"] = [] + flow_manager = FlowManager(config, self.mock_task, self.mock_openai, self.mock_tts) + await flow_manager.initialize([]) + + # Test with missing optional fields + config = self.openai_config.copy() + del config["nodes"]["start"]["pre_actions"] + flow_manager = FlowManager(config, self.mock_task, self.mock_openai, self.mock_tts) + await flow_manager.initialize([]) diff --git a/tests/test_state.py b/tests/test_state.py new file mode 100644 index 00000000..b8b8429b --- /dev/null +++ b/tests/test_state.py @@ -0,0 +1,281 @@ +import unittest +from unittest.mock import MagicMock, patch + +from pipecat.services.anthropic import AnthropicLLMService +from pipecat.services.google import GoogleLLMService +from pipecat.services.openai import OpenAILLMService + +from pipecat_flows.adapters import AnthropicAdapter, GeminiAdapter, OpenAIAdapter +from pipecat_flows.state import FlowState + + +class TestFlowState(unittest.IsolatedAsyncioTestCase): + def setUp(self): + """Set up test cases with configs for different LLM providers""" + # Create mock LLM services + self.mock_openai = MagicMock(spec=OpenAILLMService) + self.mock_anthropic = MagicMock(spec=AnthropicLLMService) + self.mock_gemini = MagicMock(spec=GoogleLLMService) + + # OpenAI format config + self.openai_config = { + "initial_node": "start", + "nodes": { + "start": { + "messages": [{"role": "system", "content": "Start node"}], + "functions": [ + { + "type": "function", + "function": { + "name": "process", + "description": "Process node function", + "parameters": {}, + }, + }, + { + "type": "function", + "function": { + "name": "middle", + "description": "Transition to middle", + "parameters": {}, + }, + }, + ], + }, + "middle": { + "messages": [{"role": "system", "content": "Middle node"}], + "functions": [ + { + "type": "function", + "function": { + "name": "end", + "description": "Transition to end", + "parameters": {}, + }, + } + ], + }, + "end": {"messages": [{"role": "system", "content": "End node"}], "functions": []}, + }, + } + + # Anthropic format config + self.anthropic_config = { + "initial_node": "start", + "nodes": { + "start": { + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Start node"}]} + ], + "functions": [ + { + "name": "process", + "description": "Process node function", + "input_schema": {}, + }, + { + "name": "middle", + "description": "Transition to middle", + "input_schema": {}, + }, + ], + }, + "middle": { + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Middle node"}]} + ], + "functions": [ + 
{"name": "end", "description": "Transition to end", "input_schema": {}} + ], + }, + "end": { + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "End node"}]} + ], + "functions": [], + }, + }, + } + + # Gemini format config + self.gemini_config = { + "initial_node": "start", + "nodes": { + "start": { + "messages": [{"role": "system", "content": "Start node"}], + "functions": [ + { + "function_declarations": [ + { + "name": "process", + "description": "Process node function", + "parameters": {}, + }, + { + "name": "middle", + "description": "Transition to middle", + "parameters": {}, + }, + ] + } + ], + }, + "middle": { + "messages": [{"role": "system", "content": "Middle node"}], + "functions": [ + { + "function_declarations": [ + { + "name": "end", + "description": "Transition to end", + "parameters": {}, + } + ] + } + ], + }, + "end": {"messages": [{"role": "system", "content": "End node"}], "functions": []}, + }, + } + + def test_initialization_with_different_llms(self): + """Test initialization with different LLM providers""" + # OpenAI + flow_openai = FlowState(self.openai_config, self.mock_openai) + self.assertEqual(flow_openai.current_node, "start") + functions = flow_openai.get_current_functions() + self.assertTrue(all("type" in f for f in functions)) + + # Anthropic + flow_anthropic = FlowState(self.anthropic_config, self.mock_anthropic) + functions = flow_anthropic.get_current_functions() + self.assertTrue(all("input_schema" in f for f in functions)) + + # Gemini + flow_gemini = FlowState(self.gemini_config, self.mock_gemini) + functions = flow_gemini.get_current_functions() + self.assertTrue("function_declarations" in functions[0]) + + @patch("pipecat_flows.state.create_adapter") + def test_initialization_errors(self, mock_create_adapter): + """Test initialization error cases""" + mock_create_adapter.return_value = OpenAIAdapter() + + # Test missing initial_node + invalid_config = {"nodes": {}} + with self.assertRaises(ValueError) as context: + FlowState(invalid_config, self.mock_openai) + self.assertEqual(str(context.exception), "Flow config must specify 'initial_node'") + + # Test missing nodes + invalid_config = {"initial_node": "start"} + with self.assertRaises(ValueError) as context: + FlowState(invalid_config, self.mock_openai) + self.assertEqual(str(context.exception), "Flow config must specify 'nodes'") + + # Test initial node not in nodes + invalid_config = {"initial_node": "invalid", "nodes": {}} + with self.assertRaises(ValueError) as context: + FlowState(invalid_config, self.mock_openai) + self.assertEqual(str(context.exception), "Initial node 'invalid' not found in nodes") + + def test_get_current_messages(self): + """Test retrieving messages for current node""" + flow = FlowState(self.openai_config, self.mock_openai) + messages = flow.get_current_messages() + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"], "Start node") + + def test_get_current_functions(self): + """Test retrieving functions for current node""" + flow = FlowState(self.openai_config, self.mock_openai) + functions = flow.get_current_functions() + self.assertEqual(len(functions), 2) + self.assertTrue(all("type" in f for f in functions)) + self.assertTrue(all("function" in f for f in functions)) + + def test_get_available_function_names(self): + """Test retrieving available function names""" + flow = FlowState(self.openai_config, self.mock_openai) + names = flow.get_available_function_names() + self.assertEqual(names, {"process", "middle"}) + + 
@patch("pipecat_flows.state.create_adapter") + def test_adapter_creation(self, mock_create_adapter): + """Test that the correct adapter is created for each LLM type""" + # Configure mock adapters + mock_create_adapter.return_value = OpenAIAdapter() + + FlowState(self.openai_config, self.mock_openai) + mock_create_adapter.assert_called_once_with(self.mock_openai) + + @patch("pipecat_flows.state.create_adapter") + def test_function_call_parsing(self, mock_create_adapter): + """Test parsing function calls for different LLM formats""" + # OpenAI + mock_create_adapter.return_value = OpenAIAdapter() + flow = FlowState(self.openai_config, self.mock_openai) + + openai_call = { + "type": "function", + "function": {"name": "process"}, + "arguments": {"data": "test"}, # Arguments at top level + } + self.assertEqual(flow.get_function_name_from_call(openai_call), "process") + self.assertEqual(flow.get_function_args_from_call(openai_call), {"data": "test"}) + + # Anthropic + mock_create_adapter.return_value = AnthropicAdapter() + flow = FlowState(self.anthropic_config, self.mock_anthropic) + + anthropic_call = {"name": "process", "arguments": {"data": "test"}} + self.assertEqual(flow.get_function_name_from_call(anthropic_call), "process") + self.assertEqual(flow.get_function_args_from_call(anthropic_call), {"data": "test"}) + + # Gemini + mock_create_adapter.return_value = GeminiAdapter() + flow = FlowState(self.gemini_config, self.mock_gemini) + + gemini_call = {"name": "process", "args": {"data": "test"}} + self.assertEqual(flow.get_function_name_from_call(gemini_call), "process") + self.assertEqual(flow.get_function_args_from_call(gemini_call), {"data": "test"}) + + def test_transition_with_different_formats(self): + """Test transitions work correctly with different function formats""" + # OpenAI + flow_openai = FlowState(self.openai_config, self.mock_openai) + self.assertEqual(flow_openai.transition("middle"), "middle") + + # Anthropic + flow_anthropic = FlowState(self.anthropic_config, self.mock_anthropic) + self.assertEqual(flow_anthropic.transition("middle"), "middle") + + # Gemini + flow_gemini = FlowState(self.gemini_config, self.mock_gemini) + self.assertEqual(flow_gemini.transition("middle"), "middle") + + def test_transition_edge_cases(self): + """Test transition edge cases and error conditions""" + flow = FlowState(self.openai_config, self.mock_openai) + + # Test transition with non-existent function + result = flow.transition("non_existent") + self.assertIsNone(result) + self.assertEqual(flow.get_current_node(), "start") + + # Test transition with node function (shouldn't change state) + result = flow.transition("process") + self.assertIsNone(result) + self.assertEqual(flow.get_current_node(), "start") + + # Test multiple valid transitions + self.assertEqual(flow.transition("middle"), "middle") + self.assertEqual(flow.get_current_node(), "middle") + self.assertEqual(flow.transition("end"), "end") + self.assertEqual(flow.get_current_node(), "end") + + def test_get_all_available_function_names(self): + """Test retrieving all function names across all nodes""" + flow = FlowState(self.openai_config, self.mock_openai) + all_names = flow.get_all_available_function_names() + self.assertEqual(all_names, {"process", "middle", "end"})