From b5f12e6877ff470021d42153b8a6a5a9516239b6 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 13:36:34 -0400 Subject: [PATCH 01/47] =?UTF-8?q?[WIP]=20Initial=20progress=20on=20?= =?UTF-8?q?=E2=80=9Cdirect=20functions=E2=80=9D:=20the=20ability=20to=20di?= =?UTF-8?q?rectly=20provide=20Python=20functions=20as=20NodeConfig=20funct?= =?UTF-8?q?ions.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When direct functions are used: - Function name, description, properties, and required are all automatically read from the Python function signature and docstring - There’s no need for separate handler and transition callbacks --- src/pipecat_flows/types.py | 12 ++++++++++++ tests/test_flows_functions.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 tests/test_flows_functions.py diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 1034f55..03dd323 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -208,6 +208,18 @@ def to_function_schema(self) -> FunctionSchema: ) +class FlowsFunction: + def __init__(self, function: Callable): + self.function = function + self._initialize_metadata() + + def _initialize_metadata(self): + self.name = self.function.__name__ + self.description = ( + self.function.__doc__ + ) # TODO: do we need a default description? what happens if it's None? + + class NodeConfigRequired(TypedDict): """Required fields for node configuration.""" diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py new file mode 100644 index 0000000..06592b4 --- /dev/null +++ b/tests/test_flows_functions.py @@ -0,0 +1,35 @@ +import unittest + +from pipecat_flows.types import FlowsFunction + +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Tests for FlowsFunction class.""" + + +class TestFlowsFunction(unittest.TestCase): + def test_name_is_set_from_function(self): + """Test that FlowsFunction extracts the name from the function.""" + + def test_function(args): + return {} + + func = FlowsFunction(function=test_function) + self.assertEqual(func.name, "test_function") + + def test_description_is_set_from_function(self): + """Test that FlowsFunction extracts the description from the function.""" + + def test_function(args): + """This is a test function.""" + return {} + + func = FlowsFunction(function=test_function) + self.assertEqual(func.description, "This is a test function.") + + +if __name__ == "__main__": + unittest.main() From 643fd8fbe66b500a61768d4ed657b41fde374fec Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 13:44:00 -0400 Subject: [PATCH 02/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 03dd323..a7fe18c 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -16,6 +16,7 @@ and function interactions. """ +import inspect from dataclasses import dataclass from enum import Enum from typing import Any, Awaitable, Callable, Dict, List, Optional, TypedDict, TypeVar, Union @@ -215,9 +216,7 @@ def __init__(self, function: Callable): def _initialize_metadata(self): self.name = self.function.__name__ - self.description = ( - self.function.__doc__ - ) # TODO: do we need a default description? what happens if it's None? + self.description = inspect.getdoc(self.function) class NodeConfigRequired(TypedDict): From 59c2a334059ef513be8f3b2d0cc8f36274941013 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 13:53:13 -0400 Subject: [PATCH 03/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index a7fe18c..a270f0e 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -216,7 +216,7 @@ def __init__(self, function: Callable): def _initialize_metadata(self): self.name = self.function.__name__ - self.description = inspect.getdoc(self.function) + self.description = inspect.getdoc(self.function) or "" class NodeConfigRequired(TypedDict): From c991f8d2933d7fdf2117db22e6cf28de94038e2c Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 15:20:39 -0400 Subject: [PATCH 04/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 139 +++++++++++++++++++++++++++++++++- tests/test_flows_functions.py | 47 ++++++++++-- 2 files changed, 179 insertions(+), 7 deletions(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index a270f0e..d769c8f 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -19,8 +19,24 @@ import inspect from dataclasses import dataclass from enum import Enum -from typing import Any, Awaitable, Callable, Dict, List, Optional, TypedDict, TypeVar, Union - +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + TypedDict, + TypeVar, + Union, + get_args, + get_origin, + get_type_hints, +) + +from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema T = TypeVar("T") @@ -215,9 +231,128 @@ def __init__(self, function: Callable): self._initialize_metadata() def _initialize_metadata(self): + # Get function name self.name = self.function.__name__ + + # Get function description + # TODO: should ignore args and return type, right? Just the top-level docstring? self.description = inspect.getdoc(self.function) or "" + # Get function properties as JSON schema + # TODO: also get whether each property is required + # TODO: is there a way to get "args" from doc string and use it to fill in descriptions? + self.properties = self._get_parameters_as_jsonschema(self.function) + + # TODO: maybe to better support things like enums, check if each type is a pydantic type and use its convert-to-jsonschema function + def _get_parameters_as_jsonschema(self, func: Callable) -> Dict[str, Any]: + """ + Get function parameters as a dictionary of JSON schemas. + + Args: + func: Function to get parameters from + + Returns: + A dictionary mapping each function parameter to its JSON schema + """ + + sig = inspect.signature(func) + hints = get_type_hints(func) + properties = {} + + # TODO: use param or ignore it + for name, param in sig.parameters.items(): + # Ignore 'self' parameter + if name == "self": + continue + + type_hint = hints.get(name) + + # Convert type hint to JSON schema + properties[name] = self._typehint_to_jsonschema(type_hint) + + return properties + + # TODO: test this way more, throwing crazy types at it + def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: + """ + Convert a Python type hint to a JSON Schema. + + Args: + hint: A Python type hint + + Returns: + A dictionary representing the JSON Schema + """ + if type_hint is None: + return {} + + # Handle basic types + if type_hint is type(None): + return {"type": "null"} + if type_hint is str: + return {"type": "string"} + elif type_hint is int: + return {"type": "integer"} + elif type_hint is float: + return {"type": "number"} + elif type_hint is bool: + return {"type": "boolean"} + elif type_hint is dict or type_hint is Dict: + return {"type": "object"} + elif type_hint is list or type_hint is List: + return {"type": "array"} + + # Get origin and arguments for complex types + origin = get_origin(type_hint) + args = get_args(type_hint) + + # Handle Optional/Union types + if origin is Union: + # Check if this is an Optional (Union with None) + has_none = type(None) in args + non_none_args = [arg for arg in args if arg is not type(None)] + + if has_none and len(non_none_args) == 1: + # This is an Optional[X] + schema = self._typehint_to_jsonschema(non_none_args[0]) + schema["nullable"] = True + return schema + else: + # This is a general Union + return {"anyOf": [self._typehint_to_jsonschema(arg) for arg in args]} + + # Handle List, Tuple, Set with specific item types + if origin in (list, List, tuple, Tuple, set, Set) and args: + return {"type": "array", "items": self._typehint_to_jsonschema(args[0])} + + # Handle Dict with specific key/value types + if origin in (dict, Dict) and len(args) == 2: + # For JSON Schema, keys must be strings + return {"type": "object", "additionalProperties": self._typehint_to_jsonschema(args[1])} + + # Handle TypedDict + if hasattr(type_hint, "__annotations__"): + properties = {} + required = [] + + for field_name, field_type in get_type_hints(type_hint).items(): + properties[field_name] = self._typehint_to_jsonschema(field_type) + # Check if field is required (this is a simplification, might need adjustment) + if not getattr(type_hint, "__total__", True) or not isinstance( + field_type, Optional + ): + required.append(field_name) + + schema = {"type": "object", "properties": properties} + + if required: + schema["required"] = required + + return schema + + # Default to any type if we can't determine the specific schema + return {} + class NodeConfigRequired(TypedDict): """Required fields for node configuration.""" diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py index 06592b4..99ddd99 100644 --- a/tests/test_flows_functions.py +++ b/tests/test_flows_functions.py @@ -1,3 +1,4 @@ +from typing import Union import unittest from pipecat_flows.types import FlowsFunction @@ -14,22 +15,58 @@ class TestFlowsFunction(unittest.TestCase): def test_name_is_set_from_function(self): """Test that FlowsFunction extracts the name from the function.""" - def test_function(args): + def my_function(): return {} - func = FlowsFunction(function=test_function) - self.assertEqual(func.name, "test_function") + func = FlowsFunction(function=my_function) + self.assertEqual(func.name, "my_function") def test_description_is_set_from_function(self): """Test that FlowsFunction extracts the description from the function.""" - def test_function(args): + def my_function(): """This is a test function.""" return {} - func = FlowsFunction(function=test_function) + func = FlowsFunction(function=my_function) self.assertEqual(func.description, "This is a test function.") + def test_properties_are_set_from_function(self): + """Test that FlowsFunction extracts the properties from the function.""" + + def my_function_no_params(): + return {} + + func = FlowsFunction(function=my_function_no_params) + self.assertEqual(func.properties, {}) + + def my_function_simple_params(name: str, age: int, height: float): + return {} + + func = FlowsFunction(function=my_function_simple_params) + self.assertEqual( + func.properties, + {"name": {"type": "string"}, "age": {"type": "integer"}, "height": {"type": "number"}}, + ) + + def my_function_complex_params( + address_lines: list[str], extra: Union[dict[str, str], None] + ): + return {} + + func = FlowsFunction(function=my_function_complex_params) + self.assertEqual( + func.properties, + { + "address_lines": {"type": "array", "items": {"type": "string"}}, + "extra": { + "type": "object", + "additionalProperties": {"type": "string"}, + "nullable": True, + }, + }, + ) + if __name__ == "__main__": unittest.main() From e02fc9983ae5249b06c717aac90d7edfd8ee682b Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 15:23:49 -0400 Subject: [PATCH 05/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_functions.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py index 99ddd99..2f05f02 100644 --- a/tests/test_flows_functions.py +++ b/tests/test_flows_functions.py @@ -40,17 +40,21 @@ def my_function_no_params(): func = FlowsFunction(function=my_function_no_params) self.assertEqual(func.properties, {}) - def my_function_simple_params(name: str, age: int, height: float): + def my_function_simple_params(name: str, age: int, height: Union[float, None]): return {} func = FlowsFunction(function=my_function_simple_params) self.assertEqual( func.properties, - {"name": {"type": "string"}, "age": {"type": "integer"}, "height": {"type": "number"}}, + { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "height": {"type": "number", "nullable": True}, + }, ) def my_function_complex_params( - address_lines: list[str], extra: Union[dict[str, str], None] + address_lines: list[str], nickname: Union[str, int], extra: Union[dict[str, str], None] ): return {} @@ -59,6 +63,7 @@ def my_function_complex_params( func.properties, { "address_lines": {"type": "array", "items": {"type": "string"}}, + "nickname": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "extra": { "type": "object", "additionalProperties": {"type": "string"}, From d4496b5375ba9d4e918f305efa31c08c2accf61e Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 15:29:28 -0400 Subject: [PATCH 06/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py index 2f05f02..3a265d6 100644 --- a/tests/test_flows_functions.py +++ b/tests/test_flows_functions.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Optional, Union import unittest from pipecat_flows.types import FlowsFunction @@ -54,7 +54,7 @@ def my_function_simple_params(name: str, age: int, height: Union[float, None]): ) def my_function_complex_params( - address_lines: list[str], nickname: Union[str, int], extra: Union[dict[str, str], None] + address_lines: list[str], nickname: Union[str, int], extra: Optional[dict[str, str]] ): return {} From 44d579f6125ae6abea8ae1e48a5b73ba73ee31ff Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 15:54:26 -0400 Subject: [PATCH 07/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 3 ++- tests/test_flows_functions.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index d769c8f..279cb40 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -19,6 +19,7 @@ import inspect from dataclasses import dataclass from enum import Enum +import types from typing import ( Any, Awaitable, @@ -307,7 +308,7 @@ def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: args = get_args(type_hint) # Handle Optional/Union types - if origin is Union: + if origin is Union or origin is types.UnionType: # Check if this is an Optional (Union with None) has_none = type(None) in args non_none_args = [arg for arg in args if arg is not type(None)] diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py index 3a265d6..bb69c44 100644 --- a/tests/test_flows_functions.py +++ b/tests/test_flows_functions.py @@ -54,7 +54,7 @@ def my_function_simple_params(name: str, age: int, height: Union[float, None]): ) def my_function_complex_params( - address_lines: list[str], nickname: Union[str, int], extra: Optional[dict[str, str]] + address_lines: list[str], nickname: str | int, extra: Optional[dict[str, str]] ): return {} From 244e61894483c6ce7c318d444b773d3aa5c51a00 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 16:31:22 -0400 Subject: [PATCH 08/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20union=20with=20`None`=20should=20?= =?UTF-8?q?not=20be=20conflated=20with=20being=20required=20(that=20should?= =?UTF-8?q?=20happen=20by=20way=20of=20a=20default=20argument)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 37 +++++++++++++----------------- tests/test_flows_functions.py | 42 +++++++++++++++++++++++++++++------ 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 279cb40..d835553 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -17,9 +17,9 @@ """ import inspect +import types from dataclasses import dataclass from enum import Enum -import types from typing import ( Any, Awaitable, @@ -239,28 +239,29 @@ def _initialize_metadata(self): # TODO: should ignore args and return type, right? Just the top-level docstring? self.description = inspect.getdoc(self.function) or "" - # Get function properties as JSON schema - # TODO: also get whether each property is required + # Get function parameters as JSON schemas, and the list of required parameters # TODO: is there a way to get "args" from doc string and use it to fill in descriptions? - self.properties = self._get_parameters_as_jsonschema(self.function) + self.properties, self.required = self._get_parameters_as_jsonschema(self.function) # TODO: maybe to better support things like enums, check if each type is a pydantic type and use its convert-to-jsonschema function - def _get_parameters_as_jsonschema(self, func: Callable) -> Dict[str, Any]: + def _get_parameters_as_jsonschema(self, func: Callable) -> Tuple[Dict[str, Any], List[str]]: """ - Get function parameters as a dictionary of JSON schemas. + Get function parameters as a dictionary of JSON schemas and a list of required parameters. Args: func: Function to get parameters from Returns: - A dictionary mapping each function parameter to its JSON schema + A tuple containing: + - A dictionary mapping each function parameter to its JSON schema + - A list of required parameter names """ sig = inspect.signature(func) hints = get_type_hints(func) properties = {} + required = [] - # TODO: use param or ignore it for name, param in sig.parameters.items(): # Ignore 'self' parameter if name == "self": @@ -271,7 +272,12 @@ def _get_parameters_as_jsonschema(self, func: Callable) -> Dict[str, Any]: # Convert type hint to JSON schema properties[name] = self._typehint_to_jsonschema(type_hint) - return properties + # Check if the parameter is required + # If the parameter has no default value, it's required + if param.default is inspect.Parameter.empty: + required.append(name) + + return properties, required # TODO: test this way more, throwing crazy types at it def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: @@ -309,18 +315,7 @@ def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: # Handle Optional/Union types if origin is Union or origin is types.UnionType: - # Check if this is an Optional (Union with None) - has_none = type(None) in args - non_none_args = [arg for arg in args if arg is not type(None)] - - if has_none and len(non_none_args) == 1: - # This is an Optional[X] - schema = self._typehint_to_jsonschema(non_none_args[0]) - schema["nullable"] = True - return schema - else: - # This is a general Union - return {"anyOf": [self._typehint_to_jsonschema(arg) for arg in args]} + return {"anyOf": [self._typehint_to_jsonschema(arg) for arg in args]} # Handle List, Tuple, Set with specific item types if origin in (list, List, tuple, Tuple, set, Set) and args: diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py index bb69c44..36348f4 100644 --- a/tests/test_flows_functions.py +++ b/tests/test_flows_functions.py @@ -1,5 +1,5 @@ -from typing import Optional, Union import unittest +from typing import Optional, Union from pipecat_flows.types import FlowsFunction @@ -49,12 +49,12 @@ def my_function_simple_params(name: str, age: int, height: Union[float, None]): { "name": {"type": "string"}, "age": {"type": "integer"}, - "height": {"type": "number", "nullable": True}, + "height": {"anyOf": [{"type": "number"}, {"type": "null"}]}, }, ) def my_function_complex_params( - address_lines: list[str], nickname: str | int, extra: Optional[dict[str, str]] + address_lines: list[str], nickname: str | int | float, extra: Optional[dict[str, str]] ): return {} @@ -63,15 +63,43 @@ def my_function_complex_params( func.properties, { "address_lines": {"type": "array", "items": {"type": "string"}}, - "nickname": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "nickname": { + "anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "number"}] + }, "extra": { - "type": "object", - "additionalProperties": {"type": "string"}, - "nullable": True, + "anyOf": [ + {"type": "object", "additionalProperties": {"type": "string"}}, + {"type": "null"}, + ] }, }, ) + def test_required_is_set_from_function(self): + """Test that FlowsFunction extracts the required properties from the function.""" + + def my_function_no_params(): + return {} + + func = FlowsFunction(function=my_function_no_params) + self.assertEqual(func.required, []) + + def my_function_simple_params(name: str, age: int, height: Union[float, None] = None): + return {} + + func = FlowsFunction(function=my_function_simple_params) + self.assertEqual(func.required, ["name", "age"]) + + def my_function_complex_params( + address_lines: list[str], + nickname: str | int | None = "Bud", + extra: Optional[dict[str, str]] = None, + ): + return {} + + func = FlowsFunction(function=my_function_complex_params) + self.assertEqual(func.required, ["address_lines"]) + if __name__ == "__main__": unittest.main() From a09ca72d936a535d92b1fc78cf49b57248f56fe9 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 21 May 2025 17:02:55 -0400 Subject: [PATCH 09/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20read=20parameter=20descriptions?= =?UTF-8?q?=20from=20docstring?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 1 + src/pipecat_flows/types.py | 23 ++++++++++++---- tests/test_flows_functions.py | 52 +++++++++++++++++++++++++++++++++-- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e4739d5..533829b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ dependencies = [ "pipecat-ai>=0.0.67", "loguru~=0.7.2", + "docstring_parser~=0.16" ] [project.urls] diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index d835553..ac3556f 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -37,6 +37,7 @@ get_type_hints, ) +import docstring_parser from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema @@ -235,16 +236,21 @@ def _initialize_metadata(self): # Get function name self.name = self.function.__name__ + # Parse docstring for description and parameters + docstring = docstring_parser.parse(inspect.getdoc(self.function)) + # Get function description - # TODO: should ignore args and return type, right? Just the top-level docstring? - self.description = inspect.getdoc(self.function) or "" + self.description = (docstring.description or "").strip() # Get function parameters as JSON schemas, and the list of required parameters - # TODO: is there a way to get "args" from doc string and use it to fill in descriptions? - self.properties, self.required = self._get_parameters_as_jsonschema(self.function) + self.properties, self.required = self._get_parameters_as_jsonschema( + self.function, docstring.params + ) # TODO: maybe to better support things like enums, check if each type is a pydantic type and use its convert-to-jsonschema function - def _get_parameters_as_jsonschema(self, func: Callable) -> Tuple[Dict[str, Any], List[str]]: + def _get_parameters_as_jsonschema( + self, func: Callable, docstring_params: List[docstring_parser.DocstringParam] + ) -> Tuple[Dict[str, Any], List[str]]: """ Get function parameters as a dictionary of JSON schemas and a list of required parameters. @@ -272,11 +278,16 @@ def _get_parameters_as_jsonschema(self, func: Callable) -> Tuple[Dict[str, Any], # Convert type hint to JSON schema properties[name] = self._typehint_to_jsonschema(type_hint) - # Check if the parameter is required + # Add whether the parameter is required # If the parameter has no default value, it's required if param.default is inspect.Parameter.empty: required.append(name) + # Add parameter description from docstring + for doc_param in docstring_params: + if doc_param.arg_name == name: + properties[name]["description"] = doc_param.description or "" + return properties, required # TODO: test this way more, throwing crazy types at it diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py index 36348f4..5ef09e3 100644 --- a/tests/test_flows_functions.py +++ b/tests/test_flows_functions.py @@ -24,13 +24,29 @@ def my_function(): def test_description_is_set_from_function(self): """Test that FlowsFunction extracts the description from the function.""" - def my_function(): + def my_function_short_description(): """This is a test function.""" return {} - func = FlowsFunction(function=my_function) + func = FlowsFunction(function=my_function_short_description) self.assertEqual(func.description, "This is a test function.") + def my_function_long_description(): + """ + This is a test function. + + It does some really cool stuff. + + Trust me, you'll want to use it. + """ + return {} + + func = FlowsFunction(function=my_function_long_description) + self.assertEqual( + func.description, + "This is a test function.\n\nIt does some really cool stuff.\n\nTrust me, you'll want to use it.", + ) + def test_properties_are_set_from_function(self): """Test that FlowsFunction extracts the properties from the function.""" @@ -100,6 +116,38 @@ def my_function_complex_params( func = FlowsFunction(function=my_function_complex_params) self.assertEqual(func.required, ["address_lines"]) + def test_property_descriptions_are_set_from_function(self): + """Test that FlowsFunction extracts the property descriptions from the function.""" + + def my_function(name: str, age: int, height: Union[float, None]): + """ + This is a test function. + + Args: + name (str): The name of the person. + age (int): The age of the person. + height (float | None): The height of the person in meters. Defaults to None. + """ + return {} + + func = FlowsFunction(function=my_function) + + # Validate that the function description is still set correctly even with the longer docstring + self.assertEqual(func.description, "This is a test function.") + + # Validate that the property descriptions are set correctly + self.assertEqual( + func.properties, + { + "name": {"type": "string", "description": "The name of the person."}, + "age": {"type": "integer", "description": "The age of the person."}, + "height": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "description": "The height of the person in meters. Defaults to None.", + }, + }, + ) + if __name__ == "__main__": unittest.main() From 205b1337f073c8dfc3dd1d35c586a8c045489cda Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 22 May 2025 09:50:46 -0400 Subject: [PATCH 10/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 2 +- tests/test_flows_direct_functions.py | 153 +++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 tests/test_flows_direct_functions.py diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index ac3556f..39cde9f 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -227,7 +227,7 @@ def to_function_schema(self) -> FunctionSchema: ) -class FlowsFunction: +class FlowsDirectFunction: def __init__(self, function: Callable): self.function = function self._initialize_metadata() diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py new file mode 100644 index 0000000..f5e268e --- /dev/null +++ b/tests/test_flows_direct_functions.py @@ -0,0 +1,153 @@ +import unittest +from typing import Optional, Union + +from pipecat_flows.types import FlowsDirectFunction + +# Copyright (c) 2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Tests for FlowsDirectFunction class.""" + + +class TestFlowsDirectFunction(unittest.TestCase): + def test_name_is_set_from_function(self): + """Test that FlowsDirectFunction extracts the name from the function.""" + + def my_function(): + return {} + + func = FlowsDirectFunction(function=my_function) + self.assertEqual(func.name, "my_function") + + def test_description_is_set_from_function(self): + """Test that FlowsDirectFunction extracts the description from the function.""" + + def my_function_short_description(): + """This is a test function.""" + return {} + + func = FlowsDirectFunction(function=my_function_short_description) + self.assertEqual(func.description, "This is a test function.") + + def my_function_long_description(): + """ + This is a test function. + + It does some really cool stuff. + + Trust me, you'll want to use it. + """ + return {} + + func = FlowsDirectFunction(function=my_function_long_description) + self.assertEqual( + func.description, + "This is a test function.\n\nIt does some really cool stuff.\n\nTrust me, you'll want to use it.", + ) + + def test_properties_are_set_from_function(self): + """Test that FlowsDirectFunction extracts the properties from the function.""" + + def my_function_no_params(): + return {} + + func = FlowsDirectFunction(function=my_function_no_params) + self.assertEqual(func.properties, {}) + + def my_function_simple_params(name: str, age: int, height: Union[float, None]): + return {} + + func = FlowsDirectFunction(function=my_function_simple_params) + self.assertEqual( + func.properties, + { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "height": {"anyOf": [{"type": "number"}, {"type": "null"}]}, + }, + ) + + def my_function_complex_params( + address_lines: list[str], nickname: str | int | float, extra: Optional[dict[str, str]] + ): + return {} + + func = FlowsDirectFunction(function=my_function_complex_params) + self.assertEqual( + func.properties, + { + "address_lines": {"type": "array", "items": {"type": "string"}}, + "nickname": { + "anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "number"}] + }, + "extra": { + "anyOf": [ + {"type": "object", "additionalProperties": {"type": "string"}}, + {"type": "null"}, + ] + }, + }, + ) + + def test_required_is_set_from_function(self): + """Test that FlowsDirectFunction extracts the required properties from the function.""" + + def my_function_no_params(): + return {} + + func = FlowsDirectFunction(function=my_function_no_params) + self.assertEqual(func.required, []) + + def my_function_simple_params(name: str, age: int, height: Union[float, None] = None): + return {} + + func = FlowsDirectFunction(function=my_function_simple_params) + self.assertEqual(func.required, ["name", "age"]) + + def my_function_complex_params( + address_lines: list[str], + nickname: str | int | None = "Bud", + extra: Optional[dict[str, str]] = None, + ): + return {} + + func = FlowsDirectFunction(function=my_function_complex_params) + self.assertEqual(func.required, ["address_lines"]) + + def test_property_descriptions_are_set_from_function(self): + """Test that FlowsDirectFunction extracts the property descriptions from the function.""" + + def my_function(name: str, age: int, height: Union[float, None]): + """ + This is a test function. + + Args: + name (str): The name of the person. + age (int): The age of the person. + height (float | None): The height of the person in meters. Defaults to None. + """ + return {} + + func = FlowsDirectFunction(function=my_function) + + # Validate that the function description is still set correctly even with the longer docstring + self.assertEqual(func.description, "This is a test function.") + + # Validate that the property descriptions are set correctly + self.assertEqual( + func.properties, + { + "name": {"type": "string", "description": "The name of the person."}, + "age": {"type": "integer", "description": "The age of the person."}, + "height": { + "anyOf": [{"type": "number"}, {"type": "null"}], + "description": "The height of the person in meters. Defaults to None.", + }, + }, + ) + + +if __name__ == "__main__": + unittest.main() From d663e8b0f8ab4a1fe20bfae959e468948dba37c6 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 22 May 2025 11:36:23 -0400 Subject: [PATCH 11/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 280 ++++++++++++++++++ src/pipecat_flows/manager.py | 15 +- src/pipecat_flows/types.py | 8 + tests/test_flows_functions.py | 153 ---------- 4 files changed, 301 insertions(+), 155 deletions(-) create mode 100644 examples/dynamic/restaurant_reservation_direct_functions.py delete mode 100644 tests/test_flows_functions.py diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py new file mode 100644 index 0000000..d2f627f --- /dev/null +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -0,0 +1,280 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys +from pathlib import Path +from typing import Dict, Optional + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +from pipecat_flows import FlowArgs, FlowManager, FlowResult, FlowsFunctionSchema, NodeConfig + +sys.path.append(str(Path(__file__).parent.parent)) +import argparse + +from runner import configure + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +# Mock reservation system +class MockReservationSystem: + """Simulates a restaurant reservation system API.""" + + def __init__(self): + # Mock data: Times that are "fully booked" + self.booked_times = {"7:00 PM", "8:00 PM"} # Changed to AM/PM format + + async def check_availability( + self, party_size: int, requested_time: str + ) -> tuple[bool, list[str]]: + """Check if a table is available for the given party size and time.""" + # Simulate API call delay + await asyncio.sleep(0.5) + + # Check if time is booked + is_available = requested_time not in self.booked_times + + # If not available, suggest alternative times + alternatives = [] + if not is_available: + base_times = ["5:00 PM", "6:00 PM", "7:00 PM", "8:00 PM", "9:00 PM", "10:00 PM"] + alternatives = [t for t in base_times if t not in self.booked_times] + + return is_available, alternatives + + +# Initialize mock system +reservation_system = MockReservationSystem() + + +# Type definitions for function results +class PartySizeResult(FlowResult): + size: int + status: str + + +class TimeResult(FlowResult): + status: str + time: str + available: bool + alternative_times: list[str] + + +# Function handlers +async def collect_party_size(size: int) -> tuple[PartySizeResult, NodeConfig]: + """ + Record the number of people in the party. + + Args: + size (int): Number of people in the party. Must be between 1 and 12. + """ + # Result: the recorded party size + result = PartySizeResult(size=size, status="success") + + # Next node: time selection + next_node = create_time_selection_node() + + return result, next_node + + +async def check_availability(time: str, party_size: int) -> tuple[TimeResult, NodeConfig]: + """ + Check availability for requested time. + + Args: + time (str): Requested reservation time in "HH:MM AM/PM" format. Must be between 5 PM and 10 PM. + party_size (int): Number of people in the party. + """ + # Check availability with mock API + is_available, alternative_times = await reservation_system.check_availability(party_size, time) + + # Result: availability status and alternative times, if any + result = TimeResult( + status="success", time=time, available=is_available, alternative_times=alternative_times + ) + + # Next node: confirmation or no availability + if is_available: + next_node = create_confirmation_node() + else: + next_node = create_no_availability_node(alternative_times) + + return result, next_node + + +# TODO: might be cleaner to not return a tuple but instead write result within the function, avoiding to have to return None. +# (Although, in the case of a function that *only* returns a result, we'd still have the same issue) +async def end_conversation() -> tuple[Optional[FlowResult], NodeConfig]: + """End the conversation.""" + return None, create_end_node() + + +# Node configurations +def create_initial_node(wait_for_user: bool) -> NodeConfig: + """Create initial node for party size collection.""" + return { + "role_messages": [ + { + "role": "system", + "content": "You are a restaurant reservation assistant for La Maison, an upscale French restaurant. Be casual and friendly. This is a voice conversation, so avoid special characters and emojis.", + } + ], + "task_messages": [ + { + "role": "system", + "content": "Warmly greet the customer and ask how many people are in their party. This is your only job for now; if the customer asks for something else, politely remind them you can't do it.", + } + ], + "functions": [collect_party_size], + "respond_immediately": not wait_for_user, + } + + +def create_time_selection_node() -> NodeConfig: + """Create node for time selection and availability check.""" + logger.debug("Creating time selection node") + return { + "task_messages": [ + { + "role": "system", + "content": "Ask what time they'd like to dine. Restaurant is open 5 PM to 10 PM.", + } + ], + "functions": [check_availability], + } + + +def create_confirmation_node() -> NodeConfig: + """Create confirmation node for successful reservations.""" + return { + "task_messages": [ + { + "role": "system", + "content": "Confirm the reservation details and ask if they need anything else.", + } + ], + "functions": [end_conversation], + } + + +def create_no_availability_node(alternative_times: list[str]) -> NodeConfig: + """Create node for handling no availability.""" + times_list = ", ".join(alternative_times) + return { + "task_messages": [ + { + "role": "system", + "content": ( + f"Apologize that the requested time is not available. " + f"Suggest these alternative times: {times_list}. " + "Ask if they'd like to try one of these times." + ), + } + ], + "functions": [check_availability, end_conversation], + } + + +def create_end_node() -> NodeConfig: + """Create the final node.""" + return { + "task_messages": [ + { + "role": "system", + "content": "Thank them and end the conversation.", + } + ], + "functions": [], + "post_actions": [{"type": "end_conversation"}], + } + + +# Main setup +async def main(wait_for_user: bool): + async with aiohttp.ClientSession() as session: + (room_url, _) = await configure(session) + + transport = DailyTransport( + room_url, + None, + "Reservation bot", + DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady + ) + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + + context = OpenAILLMContext() + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + stt, + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True)) + + # Initialize flow manager + flow_manager = FlowManager( + task=task, + llm=llm, + context_aggregator=context_aggregator, + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + logger.debug("Initializing flow manager") + await flow_manager.initialize() + logger.debug("Setting initial node") + await flow_manager.set_node("initial", create_initial_node(wait_for_user)) + + runner = PipelineRunner() + await runner.run(task) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Restaurant reservation bot") + parser.add_argument( + "--wait-for-user", + action="store_true", + help="If set, the bot will wait for the user to speak first", + ) + args = parser.parse_args() + + asyncio.run(main(args.wait_for_user)) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index eb061eb..18276e7 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -49,6 +49,7 @@ FlowArgs, FlowConfig, FlowResult, + FlowsDirectFunction, FlowsFunctionSchema, FunctionHandler, NodeConfig, @@ -522,8 +523,12 @@ async def register_function_schema(schema): ) for func_config in functions_list: + # Handle FlowsDirectFunctions + if callable(func_config): + print("[pk] It's a direct function!") + pass # Handle Gemini's nested function declarations as a special case - if ( + elif ( not isinstance(func_config, FlowsFunctionSchema) and "function_declarations" in func_config ): @@ -533,8 +538,8 @@ async def register_function_schema(schema): {"function_declarations": [declaration]} ) await register_function_schema(schema) + # Convert to FlowsFunctionSchema if needed and process it else: - # Convert to FlowsFunctionSchema if needed and process it schema = ( func_config if isinstance(func_config, FlowsFunctionSchema) @@ -686,6 +691,7 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: 2. Functions have valid configurations based on their type: - FlowsFunctionSchema objects have proper handler/transition fields - Dictionary format functions have valid handler/transition entries + - Direct functions are valid according to the FlowsDirectFunctions validation 3. Edge functions (matching node names) are allowed without handlers/transitions Args: @@ -704,6 +710,11 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: # Validate each function configuration if there are any for func in functions_list: + # If the function is callable, validate using FlowsDirectFunction + if callable(func): + FlowsDirectFunction.validate_function(func) + continue + # Extract function name using adapter (handles all formats) try: name = self.adapter.get_function_name(func) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 39cde9f..a26dfab 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -232,6 +232,14 @@ def __init__(self, function: Callable): self.function = function self._initialize_metadata() + # TODO: implement this. + # Throw an error on any validation failure. + # Add unit tests. + # Test with different callable types (functions, lambdas, methods, etc.), and different siganatures. + @staticmethod + def validate_function(function: Callable) -> None: + pass + def _initialize_metadata(self): # Get function name self.name = self.function.__name__ diff --git a/tests/test_flows_functions.py b/tests/test_flows_functions.py deleted file mode 100644 index 5ef09e3..0000000 --- a/tests/test_flows_functions.py +++ /dev/null @@ -1,153 +0,0 @@ -import unittest -from typing import Optional, Union - -from pipecat_flows.types import FlowsFunction - -# Copyright (c) 2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -"""Tests for FlowsFunction class.""" - - -class TestFlowsFunction(unittest.TestCase): - def test_name_is_set_from_function(self): - """Test that FlowsFunction extracts the name from the function.""" - - def my_function(): - return {} - - func = FlowsFunction(function=my_function) - self.assertEqual(func.name, "my_function") - - def test_description_is_set_from_function(self): - """Test that FlowsFunction extracts the description from the function.""" - - def my_function_short_description(): - """This is a test function.""" - return {} - - func = FlowsFunction(function=my_function_short_description) - self.assertEqual(func.description, "This is a test function.") - - def my_function_long_description(): - """ - This is a test function. - - It does some really cool stuff. - - Trust me, you'll want to use it. - """ - return {} - - func = FlowsFunction(function=my_function_long_description) - self.assertEqual( - func.description, - "This is a test function.\n\nIt does some really cool stuff.\n\nTrust me, you'll want to use it.", - ) - - def test_properties_are_set_from_function(self): - """Test that FlowsFunction extracts the properties from the function.""" - - def my_function_no_params(): - return {} - - func = FlowsFunction(function=my_function_no_params) - self.assertEqual(func.properties, {}) - - def my_function_simple_params(name: str, age: int, height: Union[float, None]): - return {} - - func = FlowsFunction(function=my_function_simple_params) - self.assertEqual( - func.properties, - { - "name": {"type": "string"}, - "age": {"type": "integer"}, - "height": {"anyOf": [{"type": "number"}, {"type": "null"}]}, - }, - ) - - def my_function_complex_params( - address_lines: list[str], nickname: str | int | float, extra: Optional[dict[str, str]] - ): - return {} - - func = FlowsFunction(function=my_function_complex_params) - self.assertEqual( - func.properties, - { - "address_lines": {"type": "array", "items": {"type": "string"}}, - "nickname": { - "anyOf": [{"type": "string"}, {"type": "integer"}, {"type": "number"}] - }, - "extra": { - "anyOf": [ - {"type": "object", "additionalProperties": {"type": "string"}}, - {"type": "null"}, - ] - }, - }, - ) - - def test_required_is_set_from_function(self): - """Test that FlowsFunction extracts the required properties from the function.""" - - def my_function_no_params(): - return {} - - func = FlowsFunction(function=my_function_no_params) - self.assertEqual(func.required, []) - - def my_function_simple_params(name: str, age: int, height: Union[float, None] = None): - return {} - - func = FlowsFunction(function=my_function_simple_params) - self.assertEqual(func.required, ["name", "age"]) - - def my_function_complex_params( - address_lines: list[str], - nickname: str | int | None = "Bud", - extra: Optional[dict[str, str]] = None, - ): - return {} - - func = FlowsFunction(function=my_function_complex_params) - self.assertEqual(func.required, ["address_lines"]) - - def test_property_descriptions_are_set_from_function(self): - """Test that FlowsFunction extracts the property descriptions from the function.""" - - def my_function(name: str, age: int, height: Union[float, None]): - """ - This is a test function. - - Args: - name (str): The name of the person. - age (int): The age of the person. - height (float | None): The height of the person in meters. Defaults to None. - """ - return {} - - func = FlowsFunction(function=my_function) - - # Validate that the function description is still set correctly even with the longer docstring - self.assertEqual(func.description, "This is a test function.") - - # Validate that the property descriptions are set correctly - self.assertEqual( - func.properties, - { - "name": {"type": "string", "description": "The name of the person."}, - "age": {"type": "integer", "description": "The age of the person."}, - "height": { - "anyOf": [{"type": "number"}, {"type": "null"}], - "description": "The height of the person in meters. Defaults to None.", - }, - }, - ) - - -if __name__ == "__main__": - unittest.main() From 6399351eba22c0f986d83724af8eb7ff137f1148 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 22 May 2025 14:13:57 -0400 Subject: [PATCH 12/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20direct=20function=20calls=20are?= =?UTF-8?q?=20working!=20Still=20more=20work=20to=20do=20before=20it?= =?UTF-8?q?=E2=80=99s=20ready,=20though?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/manager.py | 166 +++++++++++++++++++++++++++++++++-- src/pipecat_flows/types.py | 15 +++- 2 files changed, 172 insertions(+), 9 deletions(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 18276e7..f958ab7 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -26,6 +26,7 @@ import asyncio import inspect import sys +import uuid from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, cast from loguru import logger @@ -384,6 +385,98 @@ async def on_context_updated() -> None: return transition_func + async def _create_transition_func_from_direct_function( + self, func: FlowsDirectFunction + ) -> Callable: + """Create a transition function for the given direct function. + + Args: + func: The FlowsDirectFunction to create a transition function for + + Returns: + Callable: The created transition function + """ + name = func.name + + def decrease_pending_function_calls() -> None: + """Decrease the pending function calls counter if greater than zero.""" + if self._pending_function_calls > 0: + self._pending_function_calls -= 1 + logger.debug( + f"Function call completed: {name} (remaining: {self._pending_function_calls})" + ) + + async def on_context_updated_edge(next_node: NodeConfig, result_callback: Callable) -> None: + """Handle context updates for edge functions with transitions.""" + try: + decrease_pending_function_calls() + + # Only process transition if this was the last pending call + if self._pending_function_calls == 0: + # TODO: figure out how to elegantly get name of the next node + # TODO: handle possibility of next_node being a string identifying a node? for static flows? mabe we can just say direct functions are only supported in dynamic flows? + await self.set_node(str(uuid.uuid4()), next_node) + # Reset counter after transition completes + self._pending_function_calls = 0 + logger.debug("Reset pending function calls counter") + else: + logger.debug( + f"Skipping transition, {self._pending_function_calls} calls still pending" + ) + except Exception as e: + logger.error(f"Error in transition: {str(e)}") + self._pending_function_calls = 0 + await result_callback( + {"status": "error", "error": str(e)}, + properties=None, # Clear properties to prevent further callbacks + ) + raise # Re-raise to prevent further processing + + async def on_context_updated_node() -> None: + """Handle context updates for node functions without transitions.""" + decrease_pending_function_calls() + + async def transition_func(params: FunctionCallParams) -> None: + """Inner function that handles the actual tool invocation.""" + # TODO: implement + try: + # Track pending function call + self._pending_function_calls += 1 + logger.debug( + f"Function call pending: {name} (total: {self._pending_function_calls})" + ) + + # Execute function + # TODO: we need to pass FlowManager, too, huh (just like in _call_handler())... + result, next_node = await func.function(**params.arguments) + if result is None: + result = {"status": "acknowledged"} + logger.debug(f"Function called without 'handler' logic: {name}") + else: + logger.debug(f"Function with 'handler' logic {name}") + + # For edge functions (where there's a next node), prevent LLM completion until transition (run_llm=False) + # For node functions, allow immediate completion (run_llm=True) + async def on_context_updated() -> None: + if next_node: + await on_context_updated_edge(next_node, params.result_callback) + else: + await on_context_updated_node() + + properties = FunctionCallResultProperties( + run_llm=not next_node, + on_context_updated=on_context_updated, + ) + await params.result_callback(result, properties=properties) + + except Exception as e: + logger.error(f"Error in transition function {name}: {str(e)}") + self._pending_function_calls = 0 + error_result = {"status": "error", "error": str(e)} + await params.result_callback(error_result) + + return transition_func + def _lookup_function(self, func_name: str) -> Callable: """Look up a function by name in the main module. @@ -411,6 +504,57 @@ def _lookup_function(self, func_name: str) -> Callable: raise FlowError(error_message) + async def _register_function_schema( + self, schema: FlowsFunctionSchema, new_functions: Set[str] + ) -> None: + # TODO: add docstring + name = schema.name + handler = schema.handler + if name not in self.current_functions: + try: + # Handle special token format (e.g. "__function__:function_name") + if isinstance(handler, str) and handler.startswith("__function__:"): + func_name = handler.split(":")[1] + handler = self._lookup_function(func_name) + + # Create transition function + transition_func = await self._create_transition_func( + name, handler, schema.transition_to, schema.transition_callback + ) + + # Register function with LLM + self.llm.register_function( + name, + transition_func, + ) + + new_functions.add(name) + logger.debug(f"Registered function: {name}") + except Exception as e: + logger.error(f"Failed to register function {name}: {str(e)}") + raise FlowError(f"Function registration failed: {str(e)}") from e + + async def _register_direct_function( + self, func: FlowsDirectFunction, new_functions: Set[str] + ) -> None: + name = func.name + if name not in self.current_functions: + try: + # Create transition function + transition_func = await self._create_transition_func_from_direct_function(func) + + # Register function with LLM + self.llm.register_function( + name, + transition_func, + ) + + new_functions.add(name) + logger.debug(f"Registered function: {name}") + except Exception as e: + logger.error(f"Failed to register function {name}: {str(e)}") + raise FlowError(f"Function registration failed: {str(e)}") from e + async def _register_function( self, name: str, @@ -505,7 +649,7 @@ async def set_node(self, node_id: str, node_config: NodeConfig) -> None: messages.extend(node_config["task_messages"]) # Register functions and prepare tools - tools = [] + tools: List[FlowsFunctionSchema | FlowsDirectFunction] = [] new_functions: Set[str] = set() # Get functions list with default empty list if not provided @@ -514,19 +658,25 @@ async def set_node(self, node_id: str, node_config: NodeConfig) -> None: async def register_function_schema(schema): """Helper to register a single FlowsFunctionSchema.""" tools.append(schema) - await self._register_function( - name=schema.name, + await self._register_function_schema( + schema=schema, + new_functions=new_functions, + ) + + async def register_direct_function(func): + """Helper to register a single direct function.""" + direct_function = FlowsDirectFunction(function=func) + tools.append(direct_function) + # TODO: ensure that "traditional" (non-direct-function) examples still work + await self._register_direct_function( + func=direct_function, new_functions=new_functions, - handler=schema.handler, - transition_to=schema.transition_to, - transition_callback=schema.transition_callback, ) for func_config in functions_list: # Handle FlowsDirectFunctions if callable(func_config): - print("[pk] It's a direct function!") - pass + await register_direct_function(func_config) # Handle Gemini's nested function declarations as a special case elif ( not isinstance(func_config, FlowsFunctionSchema) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index a26dfab..9e02f80 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -235,11 +235,24 @@ def __init__(self, function: Callable): # TODO: implement this. # Throw an error on any validation failure. # Add unit tests. - # Test with different callable types (functions, lambdas, methods, etc.), and different siganatures. + # Test with different callable types (functions, lambdas, methods, etc.), and different signatures. @staticmethod def validate_function(function: Callable) -> None: pass + def to_function_schema(self) -> FunctionSchema: + """Convert to a standard FunctionSchema for use with LLMs. + + Returns: + FunctionSchema without flow-specific fields + """ + return FunctionSchema( + name=self.name, + description=self.description, + properties=self.properties, + required=self.required, + ) + def _initialize_metadata(self): # Get function name self.name = self.function.__name__ From 7cececee40acee8cfa2aaecf9b2a429d9591964b Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 22 May 2025 16:17:27 -0400 Subject: [PATCH 13/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 2 - src/pipecat_flows/manager.py | 56 +++++-------------- 2 files changed, 13 insertions(+), 45 deletions(-) diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index d2f627f..35f5a13 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -122,8 +122,6 @@ async def check_availability(time: str, party_size: int) -> tuple[TimeResult, No return result, next_node -# TODO: might be cleaner to not return a tuple but instead write result within the function, avoiding to have to return None. -# (Although, in the case of a function that *only* returns a result, we'd still have the same issue) async def end_conversation() -> tuple[Optional[FlowResult], NodeConfig]: """End the conversation.""" return None, create_end_node() diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index f958ab7..f2f279d 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -438,7 +438,6 @@ async def on_context_updated_node() -> None: async def transition_func(params: FunctionCallParams) -> None: """Inner function that handles the actual tool invocation.""" - # TODO: implement try: # Track pending function call self._pending_function_calls += 1 @@ -507,7 +506,15 @@ def _lookup_function(self, func_name: str) -> Callable: async def _register_function_schema( self, schema: FlowsFunctionSchema, new_functions: Set[str] ) -> None: - # TODO: add docstring + """Register a function with the LLM if not already registered. + + Args: + schema: The FlowsFunctionSchema for the function to register with the LLM + new_functions: Set to track newly registered functions for this node + + Raises: + FlowError: If function registration fails + """ name = schema.name handler = schema.handler if name not in self.current_functions: @@ -536,57 +543,21 @@ async def _register_function_schema( async def _register_direct_function( self, func: FlowsDirectFunction, new_functions: Set[str] - ) -> None: - name = func.name - if name not in self.current_functions: - try: - # Create transition function - transition_func = await self._create_transition_func_from_direct_function(func) - - # Register function with LLM - self.llm.register_function( - name, - transition_func, - ) - - new_functions.add(name) - logger.debug(f"Registered function: {name}") - except Exception as e: - logger.error(f"Failed to register function {name}: {str(e)}") - raise FlowError(f"Function registration failed: {str(e)}") from e - - async def _register_function( - self, - name: str, - new_functions: Set[str], - handler: Optional[Callable], - transition_to: Optional[str] = None, - transition_callback: Optional[Callable] = None, ) -> None: """Register a function with the LLM if not already registered. Args: - name: Name of the function to register with the LLM + func: The FlowsDirectFunction for the function to register with the LLM new_functions: Set to track newly registered functions for this node - handler: Either a callable function or a string. If string starts with - '__function__:', extracts the function name after the prefix - transition_to: Optional node name to transition to after function execution - transition_callback: Optional callback for dynamic transitions Raises: - FlowError: If function registration fails or handler lookup fails + FlowError: If function registration fails """ + name = func.name if name not in self.current_functions: try: - # Handle special token format (e.g. "__function__:function_name") - if isinstance(handler, str) and handler.startswith("__function__:"): - func_name = handler.split(":")[1] - handler = self._lookup_function(func_name) - # Create transition function - transition_func = await self._create_transition_func( - name, handler, transition_to, transition_callback - ) + transition_func = await self._create_transition_func_from_direct_function(func) # Register function with LLM self.llm.register_function( @@ -667,7 +638,6 @@ async def register_direct_function(func): """Helper to register a single direct function.""" direct_function = FlowsDirectFunction(function=func) tools.append(direct_function) - # TODO: ensure that "traditional" (non-direct-function) examples still work await self._register_direct_function( func=direct_function, new_functions=new_functions, From d1b3bd7354664d427feebbe4d0e4ab995b417240 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 22 May 2025 16:33:48 -0400 Subject: [PATCH 14/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20fix=20broken=20unit=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_manager.py b/tests/test_manager.py index 5d2ecdc..ffae7aa 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -28,7 +28,7 @@ from pipecat_flows.exceptions import FlowError, FlowTransitionError from pipecat_flows.manager import FlowConfig, FlowManager, NodeConfig -from pipecat_flows.types import FlowArgs, FlowResult +from pipecat_flows.types import FlowArgs, FlowResult, FlowsFunctionSchema class TestFlowManager(unittest.IsolatedAsyncioTestCase): @@ -681,7 +681,10 @@ async def test_register_function_error_handling(self): new_functions = set() with self.assertRaises(FlowError): - await flow_manager._register_function("test", None, None, new_functions) + await flow_manager._register_function_schema( + FlowsFunctionSchema(name="test", description="test", properties={}, required=[]), + new_functions, + ) async def test_action_execution_error_handling(self): """Test error handling in action execution.""" From b16547523616cceca0e9eccea1e5cf7b2d289397 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 22 May 2025 21:18:53 -0400 Subject: [PATCH 15/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/dynamic/restaurant_reservation_direct_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 35f5a13..a3a36e9 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -122,7 +122,7 @@ async def check_availability(time: str, party_size: int) -> tuple[TimeResult, No return result, next_node -async def end_conversation() -> tuple[Optional[FlowResult], NodeConfig]: +async def end_conversation() -> tuple[None, NodeConfig]: """End the conversation.""" return None, create_end_node() From 1ea7add7dee5bd5308091ab0e8f83c7eba76a4c2 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 11:11:23 -0400 Subject: [PATCH 16/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20split=20apart=20the=20concepts=20?= =?UTF-8?q?of=20=E2=80=9Cunified=E2=80=9D=20functions=20(functions=20that?= =?UTF-8?q?=20return=20both=20result=20and=20next=5Fnode,=20obviating=20th?= =?UTF-8?q?e=20need=20to=20use=20`transition=5Fto`/`transition=5Fcallback`?= =?UTF-8?q?)=20and=20=E2=80=9Cdirect=E2=80=9D=20functions=20(functions=20t?= =?UTF-8?q?hat=20are=20provided=20in=20lieu=20of=20function=20declarations?= =?UTF-8?q?/`FlowsFunctionSchema`s,=20and=20whose=20metadata=20is=20parsed?= =?UTF-8?q?=20out=20automatically=20from=20its=20signature=20and=20docstri?= =?UTF-8?q?ng).=20This=20allows=20the=20user=20to=20pass=20a=20"unified"?= =?UTF-8?q?=20function=20as=20a=20`handler`=20and=20simply=20omit=20`trans?= =?UTF-8?q?ition=5Fto`/`transition=5Fcallback`=20without=20necessarily=20o?= =?UTF-8?q?pting=20out=20of=20function=20declarations/`FlowsFunctionSchema?= =?UTF-8?q?`s.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 10 +- src/pipecat_flows/manager.py | 235 +++++++----------- src/pipecat_flows/types.py | 23 +- tests/test_manager.py | 6 +- 4 files changed, 120 insertions(+), 154 deletions(-) diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index a3a36e9..53b608d 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -81,7 +81,9 @@ class TimeResult(FlowResult): # Function handlers -async def collect_party_size(size: int) -> tuple[PartySizeResult, NodeConfig]: +async def collect_party_size( + size: int, flow_manager: FlowManager +) -> tuple[PartySizeResult, NodeConfig]: """ Record the number of people in the party. @@ -97,7 +99,9 @@ async def collect_party_size(size: int) -> tuple[PartySizeResult, NodeConfig]: return result, next_node -async def check_availability(time: str, party_size: int) -> tuple[TimeResult, NodeConfig]: +async def check_availability( + time: str, party_size: int, flow_manager: FlowManager +) -> tuple[TimeResult, NodeConfig]: """ Check availability for requested time. @@ -122,7 +126,7 @@ async def check_availability(time: str, party_size: int) -> tuple[TimeResult, No return result, next_node -async def end_conversation() -> tuple[None, NodeConfig]: +async def end_conversation(flow_manager: FlowManager) -> tuple[None, NodeConfig]: """End the conversation.""" return None, create_end_node() diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index f2f279d..affe911 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -54,6 +54,7 @@ FlowsFunctionSchema, FunctionHandler, NodeConfig, + UnifiedFunctionResult, ) if TYPE_CHECKING: @@ -226,7 +227,9 @@ def _register_action_from_config(self, action: ActionConfig) -> None: "Provide handler in action config or register manually." ) - async def _call_handler(self, handler: FunctionHandler, args: FlowArgs) -> FlowResult: + async def _call_handler( + self, handler: FunctionHandler, args: FlowArgs + ) -> FlowResult | UnifiedFunctionResult: """Call handler with appropriate parameters based on its signature. Detects whether the handler can accept a flow_manager parameter and @@ -265,7 +268,7 @@ async def _call_handler(self, handler: FunctionHandler, args: FlowArgs) -> FlowR async def _create_transition_func( self, name: str, - handler: Optional[Callable], + handler: Optional[Callable | FlowsDirectFunction], transition_to: Optional[str], transition_callback: Optional[Callable] = None, ) -> Callable: @@ -292,8 +295,6 @@ async def _create_transition_func( if transition_callback: self._validate_transition_callback(name, transition_callback) - is_edge_function = bool(transition_to) or bool(transition_callback) - def decrease_pending_function_calls() -> None: """Decrease the pending function calls counter if greater than zero.""" if self._pending_function_calls > 0: @@ -303,15 +304,35 @@ def decrease_pending_function_calls() -> None: ) async def on_context_updated_edge( - args: Dict[str, Any], result: Any, result_callback: Callable + next_node: Optional[NodeConfig], + args: Optional[Dict[str, Any]], + result: Optional[Any], + result_callback: Callable, ) -> None: - """Handle context updates for edge functions with transitions.""" + """ + Handle context updates for edge functions with transitions. + + If `next_node` is provided: + - Ignore `args` and `result` and just transition to it. + + Otherwise, if `transition_to` is available: + - Use it to look up the next node. + + Otherwise, if `transition_callback` is provided: + - Call it with `args` and `result` to determine the next node. + """ try: decrease_pending_function_calls() # Only process transition if this was the last pending call if self._pending_function_calls == 0: - if transition_to: # Static flow + if next_node: + # TODO: figure out how to elegantly get name of the next node + # TODO: handle possibility of next_node being a string identifying a node? for static flows? mabe we can just say direct functions are only supported in dynamic flows? + # TODO: put name of next_node in the debug log message + logger.debug(f"Transition to handler-returned node for: {name}") + await self.set_node(str(uuid.uuid4()), next_node) + elif transition_to: # Static flow logger.debug(f"Static transition to: {transition_to}") await self.set_node(transition_to, self.nodes[transition_to]) elif transition_callback: # Dynamic flow @@ -354,116 +375,58 @@ async def transition_func(params: FunctionCallParams) -> None: ) # Execute handler if present + is_transition_only_function = False + acknowledged_result = {"status": "acknowledged"} if handler: - result = await self._call_handler(handler, params.arguments) - logger.debug(f"Handler completed for {name}") - else: - result = {"status": "acknowledged"} - logger.debug(f"Function called without handler: {name}") - - # For edge functions, prevent LLM completion until transition (run_llm=False) - # For node functions, allow immediate completion (run_llm=True) - async def on_context_updated() -> None: - if is_edge_function: - await on_context_updated_edge( - params.arguments, result, params.result_callback - ) + # Invoke the handler with the provided arguments + if isinstance(handler, FlowsDirectFunction): + handler_response = await handler.invoke(params.arguments, self) else: - await on_context_updated_node() - - properties = FunctionCallResultProperties( - run_llm=not is_edge_function, - on_context_updated=on_context_updated, - ) - await params.result_callback(result, properties=properties) - - except Exception as e: - logger.error(f"Error in transition function {name}: {str(e)}") - self._pending_function_calls = 0 - error_result = {"status": "error", "error": str(e)} - await params.result_callback(error_result) - - return transition_func - - async def _create_transition_func_from_direct_function( - self, func: FlowsDirectFunction - ) -> Callable: - """Create a transition function for the given direct function. - - Args: - func: The FlowsDirectFunction to create a transition function for - - Returns: - Callable: The created transition function - """ - name = func.name - - def decrease_pending_function_calls() -> None: - """Decrease the pending function calls counter if greater than zero.""" - if self._pending_function_calls > 0: - self._pending_function_calls -= 1 - logger.debug( - f"Function call completed: {name} (remaining: {self._pending_function_calls})" - ) - - async def on_context_updated_edge(next_node: NodeConfig, result_callback: Callable) -> None: - """Handle context updates for edge functions with transitions.""" - try: - decrease_pending_function_calls() - - # Only process transition if this was the last pending call - if self._pending_function_calls == 0: - # TODO: figure out how to elegantly get name of the next node - # TODO: handle possibility of next_node being a string identifying a node? for static flows? mabe we can just say direct functions are only supported in dynamic flows? - await self.set_node(str(uuid.uuid4()), next_node) - # Reset counter after transition completes - self._pending_function_calls = 0 - logger.debug("Reset pending function calls counter") + handler_response = await self._call_handler(handler, params.arguments) + # Support both "unified" handlers that return (result, next_node) and handlers + # that return just the result. + if isinstance(handler_response, tuple): + result, next_node = handler_response + if result is None: + result = acknowledged_result + is_transition_only_function = True + else: + result = handler_response + next_node = None else: - logger.debug( - f"Skipping transition, {self._pending_function_calls} calls still pending" - ) - except Exception as e: - logger.error(f"Error in transition: {str(e)}") - self._pending_function_calls = 0 - await result_callback( - {"status": "error", "error": str(e)}, - properties=None, # Clear properties to prevent further callbacks - ) - raise # Re-raise to prevent further processing - - async def on_context_updated_node() -> None: - """Handle context updates for node functions without transitions.""" - decrease_pending_function_calls() - - async def transition_func(params: FunctionCallParams) -> None: - """Inner function that handles the actual tool invocation.""" - try: - # Track pending function call - self._pending_function_calls += 1 + result = acknowledged_result + next_node = None + is_transition_only_function = True + # TODO: test transition-only and non-transition-only functions using both transitional and unified functions logger.debug( - f"Function call pending: {name} (total: {self._pending_function_calls})" + f"{'Transition-only function called for' if is_transition_only_function else 'Function handler completed for'} {name}" ) - # Execute function - # TODO: we need to pass FlowManager, too, huh (just like in _call_handler())... - result, next_node = await func.function(**params.arguments) - if result is None: - result = {"status": "acknowledged"} - logger.debug(f"Function called without 'handler' logic: {name}") - else: - logger.debug(f"Function with 'handler' logic {name}") - - # For edge functions (where there's a next node), prevent LLM completion until transition (run_llm=False) + # For edge functions, prevent LLM completion until transition (run_llm=False) # For node functions, allow immediate completion (run_llm=True) + has_explicit_transition = bool(transition_to) or bool(transition_callback) + async def on_context_updated() -> None: if next_node: - await on_context_updated_edge(next_node, params.result_callback) + await on_context_updated_edge( + next_node=next_node, + args=None, + result=None, + result_callback=params.result_callback, + ) + elif has_explicit_transition: + await on_context_updated_edge( + next_node=None, + args=params.arguments, + result=result, + result_callback=params.result_callback, + ) else: await on_context_updated_node() + is_edge_function = bool(next_node) or has_explicit_transition properties = FunctionCallResultProperties( - run_llm=not next_node, + run_llm=not is_edge_function, on_context_updated=on_context_updated, ) await params.result_callback(result, properties=properties) @@ -503,20 +466,26 @@ def _lookup_function(self, func_name: str) -> Callable: raise FlowError(error_message) - async def _register_function_schema( - self, schema: FlowsFunctionSchema, new_functions: Set[str] + async def _register_function( + self, + name: str, + new_functions: Set[str], + handler: Optional[Callable | FlowsDirectFunction], + transition_to: Optional[str] = None, + transition_callback: Optional[Callable] = None, ) -> None: """Register a function with the LLM if not already registered. Args: - schema: The FlowsFunctionSchema for the function to register with the LLM + name: Name of the function to register + handler: The function handler to register + transition_to: Optional node to transition to (static flows) + transition_callback: Optional transition callback (dynamic flows) new_functions: Set to track newly registered functions for this node Raises: FlowError: If function registration fails """ - name = schema.name - handler = schema.handler if name not in self.current_functions: try: # Handle special token format (e.g. "__function__:function_name") @@ -526,39 +495,9 @@ async def _register_function_schema( # Create transition function transition_func = await self._create_transition_func( - name, handler, schema.transition_to, schema.transition_callback - ) - - # Register function with LLM - self.llm.register_function( - name, - transition_func, + name, handler, transition_to, transition_callback ) - new_functions.add(name) - logger.debug(f"Registered function: {name}") - except Exception as e: - logger.error(f"Failed to register function {name}: {str(e)}") - raise FlowError(f"Function registration failed: {str(e)}") from e - - async def _register_direct_function( - self, func: FlowsDirectFunction, new_functions: Set[str] - ) -> None: - """Register a function with the LLM if not already registered. - - Args: - func: The FlowsDirectFunction for the function to register with the LLM - new_functions: Set to track newly registered functions for this node - - Raises: - FlowError: If function registration fails - """ - name = func.name - if name not in self.current_functions: - try: - # Create transition function - transition_func = await self._create_transition_func_from_direct_function(func) - # Register function with LLM self.llm.register_function( name, @@ -626,21 +565,27 @@ async def set_node(self, node_id: str, node_config: NodeConfig) -> None: # Get functions list with default empty list if not provided functions_list = node_config.get("functions", []) - async def register_function_schema(schema): + async def register_function_schema(schema: FlowsFunctionSchema): """Helper to register a single FlowsFunctionSchema.""" tools.append(schema) - await self._register_function_schema( - schema=schema, + await self._register_function( + name=schema.name, new_functions=new_functions, + handler=schema.handler, + transition_to=schema.transition_to, + transition_callback=schema.transition_callback, ) async def register_direct_function(func): """Helper to register a single direct function.""" direct_function = FlowsDirectFunction(function=func) tools.append(direct_function) - await self._register_direct_function( - func=direct_function, + await self._register_function( + name=direct_function.name, new_functions=new_functions, + handler=direct_function, + transition_to=None, + transition_callback=None, ) for func_config in functions_list: diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 9e02f80..04e27c1 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -26,6 +26,7 @@ Callable, Dict, List, + Mapping, Optional, Set, Tuple, @@ -80,7 +81,10 @@ class FlowResult(TypedDict, total=False): } """ -LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult]] +UnifiedFunctionResult = Tuple[Optional[FlowResult], Optional["NodeConfig"]] +"""Return type for "unified" functions that do either or both of handling some processing as well as specifying the next node.""" + +LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | UnifiedFunctionResult]] """Legacy function handler that only receives arguments. Args: @@ -90,7 +94,9 @@ class FlowResult(TypedDict, total=False): FlowResult: Result of the function execution """ -FlowFunctionHandler = Callable[[FlowArgs, "FlowManager"], Awaitable[FlowResult]] +FlowFunctionHandler = Callable[ + [FlowArgs, "FlowManager"], Awaitable[FlowResult | UnifiedFunctionResult] +] """Modern function handler that receives both arguments and flow_manager. Args: @@ -240,6 +246,12 @@ def __init__(self, function: Callable): def validate_function(function: Callable) -> None: pass + async def invoke( + self, args: Mapping[str, Any], flow_manager: "FlowManager" + ) -> UnifiedFunctionResult: + # print(f"[pk] Invoking function {self.name} with args: {args}") + return await self.function(**args, flow_manager=flow_manager) + def to_function_schema(self) -> FunctionSchema: """Convert to a standard FunctionSchema for use with LLMs. @@ -274,6 +286,7 @@ def _get_parameters_as_jsonschema( ) -> Tuple[Dict[str, Any], List[str]]: """ Get function parameters as a dictionary of JSON schemas and a list of required parameters. + Ignore the last parameter, as it's expected to be the flow_manager. Args: func: Function to get parameters from @@ -294,6 +307,12 @@ def _get_parameters_as_jsonschema( if name == "self": continue + # Ignore the last parameter, which is expected to be the flow_manager + param_names = [n for n in sig.parameters] + is_last_param = name == param_names[-1] + if is_last_param: + continue + type_hint = hints.get(name) # Convert type hint to JSON schema diff --git a/tests/test_manager.py b/tests/test_manager.py index ffae7aa..4879da5 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -509,6 +509,7 @@ async def handler_no_args(): result = await flow_manager._call_handler(handler_no_args, {}) self.assertEqual(result["status"], "success") + # TODO: test async def test_transition_func_error_handling(self): """Test error handling in transition functions.""" flow_manager = FlowManager( @@ -681,10 +682,7 @@ async def test_register_function_error_handling(self): new_functions = set() with self.assertRaises(FlowError): - await flow_manager._register_function_schema( - FlowsFunctionSchema(name="test", description="test", properties={}, required=[]), - new_functions, - ) + await flow_manager._register_function("test", new_functions, None) async def test_action_execution_error_handling(self): """Test error handling in action execution.""" From 31f92069c97b9f7018d0c2121eeb9b22bcbfa3bd Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 14:26:39 -0400 Subject: [PATCH 17/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20fix=20direct=20functions=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_direct_functions.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index f5e268e..fd9c8ff 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -1,6 +1,7 @@ import unittest from typing import Optional, Union +from pipecat_flows.manager import FlowManager from pipecat_flows.types import FlowsDirectFunction # Copyright (c) 2025, Daily @@ -15,7 +16,7 @@ class TestFlowsDirectFunction(unittest.TestCase): def test_name_is_set_from_function(self): """Test that FlowsDirectFunction extracts the name from the function.""" - def my_function(): + def my_function(flow_manager: FlowManager): return {} func = FlowsDirectFunction(function=my_function) @@ -24,14 +25,14 @@ def my_function(): def test_description_is_set_from_function(self): """Test that FlowsDirectFunction extracts the description from the function.""" - def my_function_short_description(): + def my_function_short_description(flow_manager: FlowManager): """This is a test function.""" return {} func = FlowsDirectFunction(function=my_function_short_description) self.assertEqual(func.description, "This is a test function.") - def my_function_long_description(): + def my_function_long_description(flow_manager: FlowManager): """ This is a test function. @@ -50,13 +51,15 @@ def my_function_long_description(): def test_properties_are_set_from_function(self): """Test that FlowsDirectFunction extracts the properties from the function.""" - def my_function_no_params(): + def my_function_no_params(flow_manager: FlowManager): return {} func = FlowsDirectFunction(function=my_function_no_params) self.assertEqual(func.properties, {}) - def my_function_simple_params(name: str, age: int, height: Union[float, None]): + def my_function_simple_params( + name: str, age: int, height: Union[float, None], flow_manager: FlowManager + ): return {} func = FlowsDirectFunction(function=my_function_simple_params) @@ -70,7 +73,10 @@ def my_function_simple_params(name: str, age: int, height: Union[float, None]): ) def my_function_complex_params( - address_lines: list[str], nickname: str | int | float, extra: Optional[dict[str, str]] + address_lines: list[str], + nickname: str | int | float, + extra: Optional[dict[str, str]], + flow_manager: FlowManager, ): return {} @@ -94,13 +100,15 @@ def my_function_complex_params( def test_required_is_set_from_function(self): """Test that FlowsDirectFunction extracts the required properties from the function.""" - def my_function_no_params(): + def my_function_no_params(flow_manager: FlowManager): return {} func = FlowsDirectFunction(function=my_function_no_params) self.assertEqual(func.required, []) - def my_function_simple_params(name: str, age: int, height: Union[float, None] = None): + def my_function_simple_params( + name: str, age: int, height: Union[float, None] = None, flow_manager: FlowManager = None + ): return {} func = FlowsDirectFunction(function=my_function_simple_params) @@ -110,6 +118,7 @@ def my_function_complex_params( address_lines: list[str], nickname: str | int | None = "Bud", extra: Optional[dict[str, str]] = None, + flow_manager: FlowManager = None, ): return {} @@ -119,7 +128,7 @@ def my_function_complex_params( def test_property_descriptions_are_set_from_function(self): """Test that FlowsDirectFunction extracts the property descriptions from the function.""" - def my_function(name: str, age: int, height: Union[float, None]): + def my_function(name: str, age: int, height: Union[float, None], flow_manager: FlowManager): """ This is a test function. From 74492d1b1d34c9c85cf9e3fafd769139b6ba817f Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 15:44:51 -0400 Subject: [PATCH 18/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20support=20returning=20from=20a=20?= =?UTF-8?q?unified=20function=20the=20next=20node's=20name=20alongside=20i?= =?UTF-8?q?ts=20`NodeConfig`,=20to=20facilitate=20debug=20logging?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 19 +++++++++++-------- src/pipecat_flows/__init__.py | 1 + src/pipecat_flows/manager.py | 10 ++++++---- src/pipecat_flows/types.py | 7 ++++++- 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 53b608d..246ff83 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -23,7 +23,7 @@ from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat_flows import FlowArgs, FlowManager, FlowResult, FlowsFunctionSchema, NodeConfig +from pipecat_flows import FlowManager, FlowResult, NamedNodeConfig, NodeConfig sys.path.append(str(Path(__file__).parent.parent)) import argparse @@ -83,7 +83,7 @@ class TimeResult(FlowResult): # Function handlers async def collect_party_size( size: int, flow_manager: FlowManager -) -> tuple[PartySizeResult, NodeConfig]: +) -> tuple[PartySizeResult, NamedNodeConfig]: """ Record the number of people in the party. @@ -94,14 +94,15 @@ async def collect_party_size( result = PartySizeResult(size=size, status="success") # Next node: time selection - next_node = create_time_selection_node() + # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here rather than a NamedNodeConfig + next_node = "get_time", create_time_selection_node() return result, next_node async def check_availability( time: str, party_size: int, flow_manager: FlowManager -) -> tuple[TimeResult, NodeConfig]: +) -> tuple[TimeResult, NamedNodeConfig]: """ Check availability for requested time. @@ -118,17 +119,19 @@ async def check_availability( ) # Next node: confirmation or no availability + # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here rather than a NamedNodeConfig if is_available: - next_node = create_confirmation_node() + next_node = "confirm", create_confirmation_node() else: - next_node = create_no_availability_node(alternative_times) + next_node = "no_availability", create_no_availability_node(alternative_times) return result, next_node -async def end_conversation(flow_manager: FlowManager) -> tuple[None, NodeConfig]: +async def end_conversation(flow_manager: FlowManager) -> tuple[None, NamedNodeConfig]: """End the conversation.""" - return None, create_end_node() + # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here rather than a NamedNodeConfig + return None, ("end", create_end_node()) # Node configurations diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index ca10870..bdb96f2 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -72,6 +72,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): FlowResult, FlowsFunctionSchema, LegacyFunctionHandler, + NamedNodeConfig, NodeConfig, ) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index affe911..fccb510 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -327,11 +327,13 @@ async def on_context_updated_edge( # Only process transition if this was the last pending call if self._pending_function_calls == 0: if next_node: - # TODO: figure out how to elegantly get name of the next node # TODO: handle possibility of next_node being a string identifying a node? for static flows? mabe we can just say direct functions are only supported in dynamic flows? - # TODO: put name of next_node in the debug log message - logger.debug(f"Transition to handler-returned node for: {name}") - await self.set_node(str(uuid.uuid4()), next_node) + if isinstance(next_node, tuple): + next_node_name, next_node = next_node + else: + next_node_name, next_node = str(uuid.uuid4()), next_node + logger.debug(f"Transition to function-returned node: {next_node_name}") + await self.set_node(next_node_name, next_node) elif transition_to: # Static flow logger.debug(f"Static transition to: {transition_to}") await self.set_node(transition_to, self.nodes[transition_to]) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 04e27c1..231a6dc 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -81,7 +81,12 @@ class FlowResult(TypedDict, total=False): } """ -UnifiedFunctionResult = Tuple[Optional[FlowResult], Optional["NodeConfig"]] +NamedNodeConfig = tuple[str, "NodeConfig"] +"""Type alias for a node configuration with its name.""" + +UnifiedFunctionResult = Tuple[ + Optional[FlowResult], Optional["NodeConfig"] | Optional[NamedNodeConfig] +] """Return type for "unified" functions that do either or both of handling some processing as well as specifying the next node.""" LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | UnifiedFunctionResult]] From 2feb06068b1e6ce7b46190c55bf95ffa3c06c973 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 15:48:24 -0400 Subject: [PATCH 19/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20add=20`FlowsDirectFunction`=20to?= =?UTF-8?q?=20`NodeConfig.functions`=20type=20definition?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 231a6dc..95ae1f0 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -449,7 +449,7 @@ class NodeConfig(NodeConfigRequired, total=False): """ role_messages: List[Dict[str, Any]] - functions: List[Union[Dict[str, Any], FlowsFunctionSchema]] + functions: List[Union[Dict[str, Any], FlowsFunctionSchema, FlowsDirectFunction]] pre_actions: List[ActionConfig] post_actions: List[ActionConfig] context_strategy: ContextStrategyConfig From 45fa831c0762330d1e22b84bf37fb523fb9529e9 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 16:43:33 -0400 Subject: [PATCH 20/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20some=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index fccb510..64a440c 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -42,7 +42,7 @@ from .actions import ActionError, ActionManager from .adapters import create_adapter -from .exceptions import FlowError, FlowInitializationError, FlowTransitionError +from .exceptions import FlowError, FlowInitializationError, FlowTransitionError, InvalidFunctionError from .types import ( ActionConfig, ContextStrategy, @@ -395,6 +395,11 @@ async def transition_func(params: FunctionCallParams) -> None: else: result = handler_response next_node = None + # FlowsDirectFunctions should always be "unified" functions that return a tuple + if isinstance(handler, FlowsDirectFunction): + raise InvalidFunctionError( + f"Direct function {name} expected to return a tuple (result, next_node) but got {type(result)}" + ) else: result = acknowledged_result next_node = None From c6ecfb4108c14bf44e0efc6e87a4e23e1353ad2d Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 16:54:18 -0400 Subject: [PATCH 21/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20some=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/manager.py | 7 ++++++- src/pipecat_flows/types.py | 19 ++++++++++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 64a440c..c20c159 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -42,7 +42,12 @@ from .actions import ActionError, ActionManager from .adapters import create_adapter -from .exceptions import FlowError, FlowInitializationError, FlowTransitionError, InvalidFunctionError +from .exceptions import ( + FlowError, + FlowInitializationError, + FlowTransitionError, + InvalidFunctionError, +) from .types import ( ActionConfig, ContextStrategy, diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 95ae1f0..fcc8927 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -42,6 +42,8 @@ from loguru import logger from pipecat.adapters.schemas.function_schema import FunctionSchema +from pipecat_flows.exceptions import InvalidFunctionError + T = TypeVar("T") TransitionHandler = Callable[[Dict[str, T], "FlowManager"], Awaitable[None]] """Type for transition handler functions. @@ -243,13 +245,20 @@ def __init__(self, function: Callable): self.function = function self._initialize_metadata() - # TODO: implement this. - # Throw an error on any validation failure. - # Add unit tests. - # Test with different callable types (functions, lambdas, methods, etc.), and different signatures. @staticmethod def validate_function(function: Callable) -> None: - pass + if not inspect.iscoroutinefunction(function): + raise InvalidFunctionError(f"Direct function {function.__name__} must be async") + params = list(inspect.signature(function).parameters.items()) + if len(params) == 0: + raise InvalidFunctionError( + f"Direct function {function.__name__} must have at least one parameter (flow_manager)" + ) + last_param_name = params[-1][0] + if last_param_name != "flow_manager": + raise InvalidFunctionError( + f"Direct function {function.__name__} last parameter must be named 'flow_manager'" + ) async def invoke( self, args: Mapping[str, Any], flow_manager: "FlowManager" From 4b35aa1abc50f1dc845b30f752d04b64f62173d4 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 17:01:20 -0400 Subject: [PATCH 22/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20add=20tests=20for=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_direct_functions.py | 58 ++++++++++++++++++---------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index fd9c8ff..aec7a0d 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -1,6 +1,7 @@ import unittest from typing import Optional, Union +from pipecat_flows.exceptions import InvalidFunctionError from pipecat_flows.manager import FlowManager from pipecat_flows.types import FlowsDirectFunction @@ -16,8 +17,8 @@ class TestFlowsDirectFunction(unittest.TestCase): def test_name_is_set_from_function(self): """Test that FlowsDirectFunction extracts the name from the function.""" - def my_function(flow_manager: FlowManager): - return {} + async def my_function(flow_manager: FlowManager): + return {}, None func = FlowsDirectFunction(function=my_function) self.assertEqual(func.name, "my_function") @@ -25,14 +26,14 @@ def my_function(flow_manager: FlowManager): def test_description_is_set_from_function(self): """Test that FlowsDirectFunction extracts the description from the function.""" - def my_function_short_description(flow_manager: FlowManager): + async def my_function_short_description(flow_manager: FlowManager): """This is a test function.""" - return {} + return {}, None func = FlowsDirectFunction(function=my_function_short_description) self.assertEqual(func.description, "This is a test function.") - def my_function_long_description(flow_manager: FlowManager): + async def my_function_long_description(flow_manager: FlowManager): """ This is a test function. @@ -40,7 +41,7 @@ def my_function_long_description(flow_manager: FlowManager): Trust me, you'll want to use it. """ - return {} + return {}, None func = FlowsDirectFunction(function=my_function_long_description) self.assertEqual( @@ -51,16 +52,16 @@ def my_function_long_description(flow_manager: FlowManager): def test_properties_are_set_from_function(self): """Test that FlowsDirectFunction extracts the properties from the function.""" - def my_function_no_params(flow_manager: FlowManager): - return {} + async def my_function_no_params(flow_manager: FlowManager): + return {}, None func = FlowsDirectFunction(function=my_function_no_params) self.assertEqual(func.properties, {}) - def my_function_simple_params( + async def my_function_simple_params( name: str, age: int, height: Union[float, None], flow_manager: FlowManager ): - return {} + return {}, None func = FlowsDirectFunction(function=my_function_simple_params) self.assertEqual( @@ -72,13 +73,13 @@ def my_function_simple_params( }, ) - def my_function_complex_params( + async def my_function_complex_params( address_lines: list[str], nickname: str | int | float, extra: Optional[dict[str, str]], flow_manager: FlowManager, ): - return {} + return {}, None func = FlowsDirectFunction(function=my_function_complex_params) self.assertEqual( @@ -100,27 +101,27 @@ def my_function_complex_params( def test_required_is_set_from_function(self): """Test that FlowsDirectFunction extracts the required properties from the function.""" - def my_function_no_params(flow_manager: FlowManager): - return {} + async def my_function_no_params(flow_manager: FlowManager): + return {}, None func = FlowsDirectFunction(function=my_function_no_params) self.assertEqual(func.required, []) - def my_function_simple_params( + async def my_function_simple_params( name: str, age: int, height: Union[float, None] = None, flow_manager: FlowManager = None ): - return {} + return {}, None func = FlowsDirectFunction(function=my_function_simple_params) self.assertEqual(func.required, ["name", "age"]) - def my_function_complex_params( + async def my_function_complex_params( address_lines: list[str], nickname: str | int | None = "Bud", extra: Optional[dict[str, str]] = None, flow_manager: FlowManager = None, ): - return {} + return {}, None func = FlowsDirectFunction(function=my_function_complex_params) self.assertEqual(func.required, ["address_lines"]) @@ -128,7 +129,9 @@ def my_function_complex_params( def test_property_descriptions_are_set_from_function(self): """Test that FlowsDirectFunction extracts the property descriptions from the function.""" - def my_function(name: str, age: int, height: Union[float, None], flow_manager: FlowManager): + async def my_function( + name: str, age: int, height: Union[float, None], flow_manager: FlowManager + ): """ This is a test function. @@ -137,7 +140,7 @@ def my_function(name: str, age: int, height: Union[float, None], flow_manager: F age (int): The age of the person. height (float | None): The height of the person in meters. Defaults to None. """ - return {} + return {}, None func = FlowsDirectFunction(function=my_function) @@ -157,6 +160,21 @@ def my_function(name: str, age: int, height: Union[float, None], flow_manager: F }, ) + def test_invalid_functions_fail_validation(self): + """Test that invalid functions fail FlowsDirectFunction validation.""" + + def my_function_non_async(flow_manager: FlowManager): + return {}, None + + with self.assertRaises(InvalidFunctionError): + FlowsDirectFunction.validate_function(my_function_non_async) + + async def my_function_missing_flow_manager(): + return {}, None + + with self.assertRaises(InvalidFunctionError): + FlowsDirectFunction.validate_function(my_function_missing_flow_manager) + if __name__ == "__main__": unittest.main() From d21d895daea6a37123bd3a4354401d86416aba1a Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 23 May 2025 17:27:08 -0400 Subject: [PATCH 23/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20fix=20up=20and=20add=20tests=20fo?= =?UTF-8?q?r=20TypedDict=20types=20in=20direct=20function=20args?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 10 +++++--- tests/test_flows_direct_functions.py | 37 +++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index fcc8927..717a8c5 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -396,12 +396,14 @@ def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: properties = {} required = [] + # NOTE: this does not yet support some fields being required and others not, which could happen when: + # - the base class is a TypedDict with required fields (total=True or not specified) and the derived class has optional fields (total=False) + # - Python 3.11+ NotRequired is used + all_fields_required = getattr(type_hint, "__total__", True) + for field_name, field_type in get_type_hints(type_hint).items(): properties[field_name] = self._typehint_to_jsonschema(field_type) - # Check if field is required (this is a simplification, might need adjustment) - if not getattr(type_hint, "__total__", True) or not isinstance( - field_type, Optional - ): + if all_fields_required: required.append(field_name) schema = {"type": "object", "properties": properties} diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index aec7a0d..6288224 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -1,5 +1,5 @@ import unittest -from typing import Optional, Union +from typing import Optional, TypedDict, Union from pipecat_flows.exceptions import InvalidFunctionError from pipecat_flows.manager import FlowManager @@ -98,6 +98,41 @@ async def my_function_complex_params( }, ) + class MyInfo1(TypedDict): + name: str + age: int + + class MyInfo2(TypedDict, total=False): + name: str + age: int + + async def my_function_complex_type_params( + info1: MyInfo1, info2: MyInfo2, flow_manager: FlowManager + ): + return {}, None + + func = FlowsDirectFunction(function=my_function_complex_type_params) + self.assertEqual( + func.properties, + { + "info1": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + }, + "info2": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + }, + }, + ) + def test_required_is_set_from_function(self): """Test that FlowsDirectFunction extracts the required properties from the function.""" From 00ccc4c3ab878f625228f327a70cb09149c61349 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 28 May 2025 15:14:23 -0400 Subject: [PATCH 24/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20make=20`flow=5Fmanager`=20the=20f?= =?UTF-8?q?irst=20parameter=20of=20direct=20functions=20rather=20the=20las?= =?UTF-8?q?t,=20so=20that=20if=20any=20of=20the=20other=20parameters=20are?= =?UTF-8?q?=20optional=20the=20user=20doesn't=20have=20to=20awkwardly=20sp?= =?UTF-8?q?ecify=20a=20default=20value=20for=20`flow=5Fmanager`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 4 ++-- src/pipecat_flows/types.py | 16 ++++++------- tests/test_flows_direct_functions.py | 23 ++++++++++++++----- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 246ff83..5efcbf7 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -82,7 +82,7 @@ class TimeResult(FlowResult): # Function handlers async def collect_party_size( - size: int, flow_manager: FlowManager + flow_manager: FlowManager, size: int ) -> tuple[PartySizeResult, NamedNodeConfig]: """ Record the number of people in the party. @@ -101,7 +101,7 @@ async def collect_party_size( async def check_availability( - time: str, party_size: int, flow_manager: FlowManager + flow_manager: FlowManager, time: str, party_size: int ) -> tuple[TimeResult, NamedNodeConfig]: """ Check availability for requested time. diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 717a8c5..0dcd816 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -254,17 +254,17 @@ def validate_function(function: Callable) -> None: raise InvalidFunctionError( f"Direct function {function.__name__} must have at least one parameter (flow_manager)" ) - last_param_name = params[-1][0] - if last_param_name != "flow_manager": + first_param_name = params[0][0] + if first_param_name != "flow_manager": raise InvalidFunctionError( - f"Direct function {function.__name__} last parameter must be named 'flow_manager'" + f"Direct function {function.__name__} first parameter must be named 'flow_manager'" ) async def invoke( self, args: Mapping[str, Any], flow_manager: "FlowManager" ) -> UnifiedFunctionResult: # print(f"[pk] Invoking function {self.name} with args: {args}") - return await self.function(**args, flow_manager=flow_manager) + return await self.function(flow_manager=flow_manager, **args) def to_function_schema(self) -> FunctionSchema: """Convert to a standard FunctionSchema for use with LLMs. @@ -321,10 +321,10 @@ def _get_parameters_as_jsonschema( if name == "self": continue - # Ignore the last parameter, which is expected to be the flow_manager - param_names = [n for n in sig.parameters] - is_last_param = name == param_names[-1] - if is_last_param: + # Ignore the first parameter, which is expected to be the flow_manager + # (We have presumably validated that this is the case in validate_function()) + is_first_param = name == next(iter(sig.parameters)) + if is_first_param: continue type_hint = hints.get(name) diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index 6288224..2e37886 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -20,6 +20,7 @@ def test_name_is_set_from_function(self): async def my_function(flow_manager: FlowManager): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function)) func = FlowsDirectFunction(function=my_function) self.assertEqual(func.name, "my_function") @@ -30,6 +31,7 @@ async def my_function_short_description(flow_manager: FlowManager): """This is a test function.""" return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_short_description)) func = FlowsDirectFunction(function=my_function_short_description) self.assertEqual(func.description, "This is a test function.") @@ -43,6 +45,7 @@ async def my_function_long_description(flow_manager: FlowManager): """ return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_long_description)) func = FlowsDirectFunction(function=my_function_long_description) self.assertEqual( func.description, @@ -55,14 +58,16 @@ def test_properties_are_set_from_function(self): async def my_function_no_params(flow_manager: FlowManager): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_no_params)) func = FlowsDirectFunction(function=my_function_no_params) self.assertEqual(func.properties, {}) async def my_function_simple_params( - name: str, age: int, height: Union[float, None], flow_manager: FlowManager + flow_manager: FlowManager, name: str, age: int, height: Union[float, None] ): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_simple_params)) func = FlowsDirectFunction(function=my_function_simple_params) self.assertEqual( func.properties, @@ -74,13 +79,14 @@ async def my_function_simple_params( ) async def my_function_complex_params( + flow_manager: FlowManager, address_lines: list[str], nickname: str | int | float, extra: Optional[dict[str, str]], - flow_manager: FlowManager, ): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_complex_params)) func = FlowsDirectFunction(function=my_function_complex_params) self.assertEqual( func.properties, @@ -107,10 +113,11 @@ class MyInfo2(TypedDict, total=False): age: int async def my_function_complex_type_params( - info1: MyInfo1, info2: MyInfo2, flow_manager: FlowManager + flow_manager: FlowManager, info1: MyInfo1, info2: MyInfo2 ): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_complex_type_params)) func = FlowsDirectFunction(function=my_function_complex_type_params) self.assertEqual( func.properties, @@ -139,25 +146,28 @@ def test_required_is_set_from_function(self): async def my_function_no_params(flow_manager: FlowManager): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_no_params)) func = FlowsDirectFunction(function=my_function_no_params) self.assertEqual(func.required, []) async def my_function_simple_params( - name: str, age: int, height: Union[float, None] = None, flow_manager: FlowManager = None + flow_manager: FlowManager, name: str, age: int, height: Union[float, None] = None ): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_simple_params)) func = FlowsDirectFunction(function=my_function_simple_params) self.assertEqual(func.required, ["name", "age"]) async def my_function_complex_params( + flow_manager: FlowManager, address_lines: list[str], nickname: str | int | None = "Bud", extra: Optional[dict[str, str]] = None, - flow_manager: FlowManager = None, ): return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function_complex_params)) func = FlowsDirectFunction(function=my_function_complex_params) self.assertEqual(func.required, ["address_lines"]) @@ -165,7 +175,7 @@ def test_property_descriptions_are_set_from_function(self): """Test that FlowsDirectFunction extracts the property descriptions from the function.""" async def my_function( - name: str, age: int, height: Union[float, None], flow_manager: FlowManager + flow_manager: FlowManager, name: str, age: int, height: Union[float, None] ): """ This is a test function. @@ -177,6 +187,7 @@ async def my_function( """ return {}, None + self.assertIsNone(FlowsDirectFunction.validate_function(my_function)) func = FlowsDirectFunction(function=my_function) # Validate that the function description is still set correctly even with the longer docstring From a2425aa2a991e9ab7dadfc719037399bcd1bd4a9 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 28 May 2025 15:18:27 -0400 Subject: [PATCH 25/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20add=20additional=20test=20for=20c?= =?UTF-8?q?hecking=20validation=20that=20`flow=5Fmanager`=20argument=20is?= =?UTF-8?q?=20in=20the=20right=20place?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_direct_functions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index 2e37886..b172601 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -221,6 +221,12 @@ async def my_function_missing_flow_manager(): with self.assertRaises(InvalidFunctionError): FlowsDirectFunction.validate_function(my_function_missing_flow_manager) + async def my_function_misplaced_flow_manager(foo: str, flow_manager: FlowManager): + return {}, None + + with self.assertRaises(InvalidFunctionError): + FlowsDirectFunction.validate_function(my_function_misplaced_flow_manager) + if __name__ == "__main__": unittest.main() From ca2427c65468192370233c144dc581b8a42eb397 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 28 May 2025 15:23:13 -0400 Subject: [PATCH 26/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20add=20test=20for=20`FlowsDirectFu?= =?UTF-8?q?nction.invoke`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 1 - tests/test_flows_direct_functions.py | 25 +++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 0dcd816..0898978 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -344,7 +344,6 @@ def _get_parameters_as_jsonschema( return properties, required - # TODO: test this way more, throwing crazy types at it def _typehint_to_jsonschema(self, type_hint: Any) -> Dict[str, Any]: """ Convert a Python type hint to a JSON Schema. diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index b172601..3ea775f 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -1,3 +1,4 @@ +import asyncio import unittest from typing import Optional, TypedDict, Union @@ -227,6 +228,30 @@ async def my_function_misplaced_flow_manager(foo: str, flow_manager: FlowManager with self.assertRaises(InvalidFunctionError): FlowsDirectFunction.validate_function(my_function_misplaced_flow_manager) + def test_invoke_calls_function_with_args_and_flow_manager(self): + """Test that FlowsDirectFunction.invoke calls the function with correct args and flow_manager.""" + + called = {} + + class DummyFlowManager: + pass + + async def my_function(flow_manager: DummyFlowManager, name: str, age: int): + called["flow_manager"] = flow_manager + called["name"] = name + called["age"] = age + return {"status": "success"}, None + + func = FlowsDirectFunction(function=my_function) + flow_manager = DummyFlowManager() + args = {"name": "Alice", "age": 30} + + result = asyncio.run(func.invoke(args=args, flow_manager=flow_manager)) + self.assertEqual(result, ({"status": "success"}, None)) + self.assertIs(called["flow_manager"], flow_manager) + self.assertEqual(called["name"], "Alice") + self.assertEqual(called["age"], 30) + if __name__ == "__main__": unittest.main() From 90fa1c30e2cf53cb77f3a845bedb9e8dc8df9111 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 28 May 2025 15:27:42 -0400 Subject: [PATCH 27/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20return=20more=20realistic=20`Flow?= =?UTF-8?q?Result`s=20from=20direct=20functions=20exercised=20in=20unit=20?= =?UTF-8?q?tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_direct_functions.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index 3ea775f..1d4e508 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -19,7 +19,7 @@ def test_name_is_set_from_function(self): """Test that FlowsDirectFunction extracts the name from the function.""" async def my_function(flow_manager: FlowManager): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function)) func = FlowsDirectFunction(function=my_function) @@ -30,7 +30,7 @@ def test_description_is_set_from_function(self): async def my_function_short_description(flow_manager: FlowManager): """This is a test function.""" - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_short_description)) func = FlowsDirectFunction(function=my_function_short_description) @@ -44,7 +44,7 @@ async def my_function_long_description(flow_manager: FlowManager): Trust me, you'll want to use it. """ - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_long_description)) func = FlowsDirectFunction(function=my_function_long_description) @@ -57,7 +57,7 @@ def test_properties_are_set_from_function(self): """Test that FlowsDirectFunction extracts the properties from the function.""" async def my_function_no_params(flow_manager: FlowManager): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_no_params)) func = FlowsDirectFunction(function=my_function_no_params) @@ -66,7 +66,7 @@ async def my_function_no_params(flow_manager: FlowManager): async def my_function_simple_params( flow_manager: FlowManager, name: str, age: int, height: Union[float, None] ): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_simple_params)) func = FlowsDirectFunction(function=my_function_simple_params) @@ -85,7 +85,7 @@ async def my_function_complex_params( nickname: str | int | float, extra: Optional[dict[str, str]], ): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_complex_params)) func = FlowsDirectFunction(function=my_function_complex_params) @@ -116,7 +116,7 @@ class MyInfo2(TypedDict, total=False): async def my_function_complex_type_params( flow_manager: FlowManager, info1: MyInfo1, info2: MyInfo2 ): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_complex_type_params)) func = FlowsDirectFunction(function=my_function_complex_type_params) @@ -145,7 +145,7 @@ def test_required_is_set_from_function(self): """Test that FlowsDirectFunction extracts the required properties from the function.""" async def my_function_no_params(flow_manager: FlowManager): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_no_params)) func = FlowsDirectFunction(function=my_function_no_params) @@ -154,7 +154,7 @@ async def my_function_no_params(flow_manager: FlowManager): async def my_function_simple_params( flow_manager: FlowManager, name: str, age: int, height: Union[float, None] = None ): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_simple_params)) func = FlowsDirectFunction(function=my_function_simple_params) @@ -166,7 +166,7 @@ async def my_function_complex_params( nickname: str | int | None = "Bud", extra: Optional[dict[str, str]] = None, ): - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function_complex_params)) func = FlowsDirectFunction(function=my_function_complex_params) @@ -186,7 +186,7 @@ async def my_function( age (int): The age of the person. height (float | None): The height of the person in meters. Defaults to None. """ - return {}, None + return {"status": "success"}, None self.assertIsNone(FlowsDirectFunction.validate_function(my_function)) func = FlowsDirectFunction(function=my_function) @@ -211,19 +211,19 @@ def test_invalid_functions_fail_validation(self): """Test that invalid functions fail FlowsDirectFunction validation.""" def my_function_non_async(flow_manager: FlowManager): - return {}, None + return {"status": "success"}, None with self.assertRaises(InvalidFunctionError): FlowsDirectFunction.validate_function(my_function_non_async) async def my_function_missing_flow_manager(): - return {}, None + return {"status": "success"}, None with self.assertRaises(InvalidFunctionError): FlowsDirectFunction.validate_function(my_function_missing_flow_manager) async def my_function_misplaced_flow_manager(foo: str, flow_manager: FlowManager): - return {}, None + return {"status": "success"}, None with self.assertRaises(InvalidFunctionError): FlowsDirectFunction.validate_function(my_function_misplaced_flow_manager) From 59b985066129ef716f8871292dcdbfc2e517338c Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Wed, 28 May 2025 16:15:01 -0400 Subject: [PATCH 28/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20support=20static=20flows?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 14 +- .../static/food_ordering_direct_functions.py | 300 ++++++++++++++++++ src/pipecat_flows/__init__.py | 2 +- src/pipecat_flows/manager.py | 12 +- src/pipecat_flows/types.py | 12 +- 5 files changed, 322 insertions(+), 18 deletions(-) create mode 100644 examples/static/food_ordering_direct_functions.py diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 5efcbf7..43e7529 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -23,7 +23,7 @@ from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat_flows import FlowManager, FlowResult, NamedNodeConfig, NodeConfig +from pipecat_flows import FlowManager, FlowResult, NamedNode, NodeConfig sys.path.append(str(Path(__file__).parent.parent)) import argparse @@ -83,7 +83,7 @@ class TimeResult(FlowResult): # Function handlers async def collect_party_size( flow_manager: FlowManager, size: int -) -> tuple[PartySizeResult, NamedNodeConfig]: +) -> tuple[PartySizeResult, NamedNode]: """ Record the number of people in the party. @@ -94,7 +94,7 @@ async def collect_party_size( result = PartySizeResult(size=size, status="success") # Next node: time selection - # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here rather than a NamedNodeConfig + # NOTE: name is optional, but useful for debug logging; you could pass a NodeConfig here directly next_node = "get_time", create_time_selection_node() return result, next_node @@ -102,7 +102,7 @@ async def collect_party_size( async def check_availability( flow_manager: FlowManager, time: str, party_size: int -) -> tuple[TimeResult, NamedNodeConfig]: +) -> tuple[TimeResult, NamedNode]: """ Check availability for requested time. @@ -119,7 +119,7 @@ async def check_availability( ) # Next node: confirmation or no availability - # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here rather than a NamedNodeConfig + # NOTE: name is optional, but useful for debug logging; you could pass a NodeConfig here directly if is_available: next_node = "confirm", create_confirmation_node() else: @@ -128,9 +128,9 @@ async def check_availability( return result, next_node -async def end_conversation(flow_manager: FlowManager) -> tuple[None, NamedNodeConfig]: +async def end_conversation(flow_manager: FlowManager) -> tuple[None, NamedNode]: """End the conversation.""" - # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here rather than a NamedNodeConfig + # NOTE: name is optional, but useful for debug logging; you could pass a NodeConfig here directly return None, ("end", create_end_node()) diff --git a/examples/static/food_ordering_direct_functions.py b/examples/static/food_ordering_direct_functions.py new file mode 100644 index 0000000..c90e4a3 --- /dev/null +++ b/examples/static/food_ordering_direct_functions.py @@ -0,0 +1,300 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys +from pathlib import Path + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.cartesia.tts import CartesiaTTSService +from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.openai.llm import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +from pipecat_flows import FlowArgs, FlowConfig, FlowManager, FlowResult + +sys.path.append(str(Path(__file__).parent.parent)) +from runner import configure + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + +# Flow Configuration - Food ordering +# +# This configuration defines a food ordering system with the following states: +# +# 1. start +# - Initial state where user chooses between pizza or sushi +# - Functions: +# * choose_pizza (transitions to choose_pizza) +# * choose_sushi (transitions to choose_sushi) +# +# 2. choose_pizza +# - Handles pizza order details +# - Functions: +# * select_pizza_order (node function with size and type) +# * confirm_order (transitions to confirm) +# - Pricing: +# * Small: $10 +# * Medium: $15 +# * Large: $20 +# +# 3. choose_sushi +# - Handles sushi order details +# - Functions: +# * select_sushi_order (node function with count and type) +# * confirm_order (transitions to confirm) +# - Pricing: +# * $8 per roll +# +# 4. confirm +# - Reviews order details with the user +# - Functions: +# * complete_order (transitions to end) +# +# 5. end +# - Final state that closes the conversation +# - No functions available +# - Post-action: Ends conversation + + +# Type definitions +class PizzaOrderResult(FlowResult): + size: str + type: str + price: float + + +class SushiOrderResult(FlowResult): + count: int + type: str + price: float + + +# Actions +async def check_kitchen_status(action: dict) -> None: + """Check if kitchen is open and log status.""" + logger.info("Checking kitchen status") + + +# Functions +async def choose_pizza(flow_manager: FlowManager) -> tuple[None, str]: + """ + User wants to order pizza. Let's get that order started. + """ + return None, "choose_pizza" + + +async def choose_sushi(flow_manager: FlowManager) -> tuple[None, str]: + """ + User wants to order sushi. Let's get that order started. + """ + return None, "choose_sushi" + + +async def select_pizza_order( + flow_manager: FlowManager, size: str, pizza_type: str +) -> tuple[PizzaOrderResult, str]: + """ + Record the pizza order details. + + Args: + size (str): Size of the pizza. Must be one of "small", "medium", or "large". + pizza_type (str): Type of pizza. Must be one of "pepperoni", "cheese", "supreme", or "vegetarian". + """ + # Simple pricing + base_price = {"small": 10.00, "medium": 15.00, "large": 20.00} + price = base_price[size] + + return {"size": size, "type": pizza_type, "price": price}, "confirm" + + +async def select_sushi_order( + flow_manager: FlowManager, count: int, roll_type: str +) -> tuple[SushiOrderResult, str]: + """ + Record the sushi order details. + + Args: + count (int): Number of sushi rolls to order. Must be between 1 and 10. + roll_type (str): Type of sushi roll. Must be one of "california", "spicy tuna", "rainbow", or "dragon". + """ + # Simple pricing: $8 per roll + price = count * 8.00 + + return {"count": count, "type": roll_type, "price": price}, "confirm" + + +async def complete_order(flow_manager: FlowManager) -> tuple[None, str]: + """ + User confirms the order is correct. + """ + return None, "end" + + +async def revise_order(flow_manager: FlowManager) -> tuple[None, str]: + """ + User wants to make changes to their order. + """ + return None, "start" + + +flow_config: FlowConfig = { + "initial_node": "start", + "nodes": { + "start": { + "role_messages": [ + { + "role": "system", + "content": "You are an order-taking assistant. You must ALWAYS use the available functions to progress the conversation. This is a phone conversation and your responses will be converted to audio. Keep the conversation friendly, casual, and polite. Avoid outputting special characters and emojis.", + } + ], + "task_messages": [ + { + "role": "system", + "content": "For this step, ask the user if they want pizza or sushi, and wait for them to use a function to choose. Start off by greeting them. Be friendly and casual; you're taking an order for food over the phone.", + } + ], + "pre_actions": [ + { + "type": "check_kitchen", + "handler": check_kitchen_status, + }, + ], + "functions": [choose_pizza, choose_sushi], + }, + "choose_pizza": { + "task_messages": [ + { + "role": "system", + "content": """You are handling a pizza order. Use the available functions: +- Use select_pizza_order when the user specifies both size AND type + +Pricing: +- Small: $10 +- Medium: $15 +- Large: $20 + +Remember to be friendly and casual.""", + } + ], + "functions": [select_pizza_order], + }, + "choose_sushi": { + "task_messages": [ + { + "role": "system", + "content": """You are handling a sushi order. Use the available functions: +- Use select_sushi_order when the user specifies both count AND type + +Pricing: +- $8 per roll + +Remember to be friendly and casual.""", + } + ], + "functions": [select_sushi_order], + }, + "confirm": { + "task_messages": [ + { + "role": "system", + "content": """Read back the complete order details to the user and if they want anything else or if they want to make changes. Use the available functions: +- Use complete_order when the user confirms that the order is correct and no changes are needed +- Use revise_order if they want to change something + +Be friendly and clear when reading back the order details.""", + } + ], + "functions": [complete_order, revise_order], + }, + "end": { + "task_messages": [ + { + "role": "system", + "content": "Thank the user for their order and end the conversation politely and concisely.", + } + ], + "functions": [], + "post_actions": [{"type": "end_conversation"}], + }, + }, +} + + +async def main(): + """Main function to set up and run the food ordering bot.""" + async with aiohttp.ClientSession() as session: + (room_url, _) = await configure(session) + + # Initialize services + transport = DailyTransport( + room_url, + None, + "Food Ordering Bot", + DailyParams( + audio_in_enabled=True, + audio_out_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + ), + ) + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="820a3788-2b37-4d21-847a-b65d8a68c99a", # Salesman + ) + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + + context = OpenAILLMContext() + context_aggregator = llm.create_context_aggregator(context) + + # Create pipeline + pipeline = Pipeline( + [ + transport.input(), + stt, + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True)) + + # Initialize flow manager in static mode + flow_manager = FlowManager( + task=task, + llm=llm, + context_aggregator=context_aggregator, + tts=tts, + flow_config=flow_config, + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + logger.debug("Initializing flow") + await flow_manager.initialize() + + runner = PipelineRunner() + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index bdb96f2..d9c6b56 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -72,7 +72,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): FlowResult, FlowsFunctionSchema, LegacyFunctionHandler, - NamedNodeConfig, + NamedNode, NodeConfig, ) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index c20c159..5469190 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -58,6 +58,7 @@ FlowsDirectFunction, FlowsFunctionSchema, FunctionHandler, + NamedNode, NodeConfig, UnifiedFunctionResult, ) @@ -309,7 +310,7 @@ def decrease_pending_function_calls() -> None: ) async def on_context_updated_edge( - next_node: Optional[NodeConfig], + next_node: Optional[NodeConfig | NamedNode], args: Optional[Dict[str, Any]], result: Optional[Any], result_callback: Callable, @@ -331,11 +332,12 @@ async def on_context_updated_edge( # Only process transition if this was the last pending call if self._pending_function_calls == 0: - if next_node: - # TODO: handle possibility of next_node being a string identifying a node? for static flows? mabe we can just say direct functions are only supported in dynamic flows? - if isinstance(next_node, tuple): + if next_node: # Function-returned next node (as opposed to next node specified by transition_*) + if isinstance(next_node, str): # Static flow + next_node_name, next_node = next_node, self.nodes[next_node] + elif isinstance(next_node, tuple): # Dynamic flow with named node next_node_name, next_node = next_node - else: + else: # Dynamic flow with anonymous node next_node_name, next_node = str(uuid.uuid4()), next_node logger.debug(f"Transition to function-returned node: {next_node_name}") await self.set_node(next_node_name, next_node) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 0898978..b5b8b0b 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -83,12 +83,14 @@ class FlowResult(TypedDict, total=False): } """ -NamedNodeConfig = tuple[str, "NodeConfig"] -"""Type alias for a node configuration with its name.""" +NamedNode = str | tuple[str, "NodeConfig"] +""" +Type alias for a named node, which can either be: +- A string representing the node name (for static flows) +- A tuple containing the node name and a NodeConfig instance (for dynamic flows) +""" -UnifiedFunctionResult = Tuple[ - Optional[FlowResult], Optional["NodeConfig"] | Optional[NamedNodeConfig] -] +UnifiedFunctionResult = Tuple[Optional[FlowResult], Optional["NodeConfig"] | Optional[NamedNode]] """Return type for "unified" functions that do either or both of handling some processing as well as specifying the next node.""" LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | UnifiedFunctionResult]] From b9835004023bbda26506cec6cac0b586140206bd Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 11:22:13 -0400 Subject: [PATCH 29/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20some=20typing=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/__init__.py | 5 +++++ src/pipecat_flows/manager.py | 5 +++-- src/pipecat_flows/types.py | 25 +++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index d9c6b56..d52dbe4 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -66,6 +66,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): from .types import ( ContextStrategy, ContextStrategyConfig, + DirectFunction, FlowArgs, FlowConfig, FlowFunctionHandler, @@ -74,6 +75,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): LegacyFunctionHandler, NamedNode, NodeConfig, + UnifiedFunctionResult, ) __all__ = [ @@ -86,8 +88,11 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): "FlowConfig", "FlowFunctionHandler", "FlowResult", + "UnifiedFunctionResult", "FlowsFunctionSchema", "LegacyFunctionHandler", + "DirectFunction", + "NamedNode", "NodeConfig", # Exceptions "FlowError", diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 5469190..715fd43 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -492,7 +492,8 @@ async def _register_function( Args: name: Name of the function to register - handler: The function handler to register + handler: A callable function handler, a FlowsDirectFunction, or a string. + If string starts with '__function__:', extracts the function name after the prefix. transition_to: Optional node to transition to (static flows) transition_callback: Optional transition callback (dynamic flows) new_functions: Set to track newly registered functions for this node @@ -603,7 +604,7 @@ async def register_direct_function(func): ) for func_config in functions_list: - # Handle FlowsDirectFunctions + # Handle direct functions if callable(func_config): await register_direct_function(func_config) # Handle Gemini's nested function declarations as a special case diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index b5b8b0b..c566aea 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -28,6 +28,7 @@ List, Mapping, Optional, + Protocol, Set, Tuple, TypedDict, @@ -116,10 +117,30 @@ class FlowResult(TypedDict, total=False): FlowResult: Result of the function execution """ + FunctionHandler = Union[LegacyFunctionHandler, FlowFunctionHandler] """Union type for function handlers supporting both legacy and modern patterns.""" +class DirectFunction(Protocol): + """ + \"Direct\" function whose definition is automatically extracted from the function signature and docstring. + This can be used in NodeConfigs directly, in lieu of a FlowsFunctionSchema or function definition dict. + + Args: + flow_manager: Reference to the FlowManager instance + **kwargs: Additional keyword arguments + + Returns: + UnifiedFunctionResult: Result of the function execution, which can include both a FlowResult + and the next node to transition to. + """ + + def __call__( + self, flow_manager: "FlowManager", **kwargs: Any + ) -> Awaitable[UnifiedFunctionResult]: ... + + LegacyActionHandler = Callable[[Dict[str, Any]], Awaitable[None]] """Legacy action handler type that only receives the action dictionary. @@ -433,7 +454,7 @@ class NodeConfig(NodeConfigRequired, total=False): Optional fields: role_messages: List of message dicts defining the bot's role/personality functions: List of function definitions in provider-specific format, FunctionSchema, - or FlowsFunctionSchema + or FlowsFunctionSchema; or a "direct function" whose definition is automatically extracted pre_actions: Actions to execute before LLM inference post_actions: Actions to execute after LLM inference context_strategy: Strategy for updating context during transitions @@ -461,7 +482,7 @@ class NodeConfig(NodeConfigRequired, total=False): """ role_messages: List[Dict[str, Any]] - functions: List[Union[Dict[str, Any], FlowsFunctionSchema, FlowsDirectFunction]] + functions: List[Union[Dict[str, Any], FlowsFunctionSchema, DirectFunction]] pre_actions: List[ActionConfig] post_actions: List[ActionConfig] context_strategy: ContextStrategyConfig From daef6fa619083633aba77aa2b12c8c56930442b3 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 11:26:40 -0400 Subject: [PATCH 30/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20some=20typing=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index c566aea..983a8c6 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -91,7 +91,7 @@ class FlowResult(TypedDict, total=False): - A tuple containing the node name and a NodeConfig instance (for dynamic flows) """ -UnifiedFunctionResult = Tuple[Optional[FlowResult], Optional["NodeConfig"] | Optional[NamedNode]] +UnifiedFunctionResult = Tuple[Optional[FlowResult], Optional[Union["NodeConfig", NamedNode]]] """Return type for "unified" functions that do either or both of handling some processing as well as specifying the next node.""" LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | UnifiedFunctionResult]] From c317da797540ad403db262e6b322527ecef0d457 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 12:29:00 -0400 Subject: [PATCH 31/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20add=20support=20for=20direct=20fu?= =?UTF-8?q?nctions=20with=20Gemini?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/adapters.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/pipecat_flows/adapters.py b/src/pipecat_flows/adapters.py index eb48c0d..75f339f 100644 --- a/src/pipecat_flows/adapters.py +++ b/src/pipecat_flows/adapters.py @@ -18,7 +18,7 @@ """ import sys -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from loguru import logger from pipecat.adapters.base_llm_adapter import BaseLLMAdapter @@ -29,7 +29,7 @@ from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter -from .types import FlowsFunctionSchema +from .types import FlowsDirectFunction, FlowsFunctionSchema class LLMAdapter: @@ -405,6 +405,20 @@ def format_functions( }, } ) + elif isinstance(func_config, Callable): + # Convert direct function to Gemini format + direct_func = FlowsDirectFunction(function=func_config) + gemini_functions.append( + { + "name": direct_func.name, + "description": direct_func.description, + "parameters": { + "type": "object", + "properties": direct_func.properties, + "required": direct_func.required, + }, + } + ) elif "function_declarations" in func_config: # Already in Gemini format, use directly but remove handler/transition fields for decl in func_config["function_declarations"]: From 70467937e6a9c3902fabe31a29c00f05d1b51ae3 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 12:42:47 -0400 Subject: [PATCH 32/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20rename=20"unified"=20functions=20?= =?UTF-8?q?to=20"consolidated"=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/__init__.py | 4 ++-- src/pipecat_flows/manager.py | 9 ++++----- src/pipecat_flows/types.py | 20 ++++++++++++-------- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index d52dbe4..dee7bb9 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -75,7 +75,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): LegacyFunctionHandler, NamedNode, NodeConfig, - UnifiedFunctionResult, + ConsolidatedFunctionResult, ) __all__ = [ @@ -88,7 +88,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): "FlowConfig", "FlowFunctionHandler", "FlowResult", - "UnifiedFunctionResult", + "ConsolidatedFunctionResult", "FlowsFunctionSchema", "LegacyFunctionHandler", "DirectFunction", diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 715fd43..feefd04 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -60,7 +60,7 @@ FunctionHandler, NamedNode, NodeConfig, - UnifiedFunctionResult, + ConsolidatedFunctionResult, ) if TYPE_CHECKING: @@ -235,7 +235,7 @@ def _register_action_from_config(self, action: ActionConfig) -> None: async def _call_handler( self, handler: FunctionHandler, args: FlowArgs - ) -> FlowResult | UnifiedFunctionResult: + ) -> FlowResult | ConsolidatedFunctionResult: """Call handler with appropriate parameters based on its signature. Detects whether the handler can accept a flow_manager parameter and @@ -392,7 +392,7 @@ async def transition_func(params: FunctionCallParams) -> None: handler_response = await handler.invoke(params.arguments, self) else: handler_response = await self._call_handler(handler, params.arguments) - # Support both "unified" handlers that return (result, next_node) and handlers + # Support both "consolidated" handlers that return (result, next_node) and handlers # that return just the result. if isinstance(handler_response, tuple): result, next_node = handler_response @@ -402,7 +402,7 @@ async def transition_func(params: FunctionCallParams) -> None: else: result = handler_response next_node = None - # FlowsDirectFunctions should always be "unified" functions that return a tuple + # FlowsDirectFunctions should always be "consolidated" functions that return a tuple if isinstance(handler, FlowsDirectFunction): raise InvalidFunctionError( f"Direct function {name} expected to return a tuple (result, next_node) but got {type(result)}" @@ -411,7 +411,6 @@ async def transition_func(params: FunctionCallParams) -> None: result = acknowledged_result next_node = None is_transition_only_function = True - # TODO: test transition-only and non-transition-only functions using both transitional and unified functions logger.debug( f"{'Transition-only function called for' if is_transition_only_function else 'Function handler completed for'} {name}" ) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 983a8c6..2a168eb 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -91,10 +91,14 @@ class FlowResult(TypedDict, total=False): - A tuple containing the node name and a NodeConfig instance (for dynamic flows) """ -UnifiedFunctionResult = Tuple[Optional[FlowResult], Optional[Union["NodeConfig", NamedNode]]] -"""Return type for "unified" functions that do either or both of handling some processing as well as specifying the next node.""" +ConsolidatedFunctionResult = Tuple[Optional[FlowResult], Optional[Union["NodeConfig", NamedNode]]] +""" +Return type for "consolidated" functions that do either or both of: +- doing some work +- specifying the next node to transition to after the work is done +""" -LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | UnifiedFunctionResult]] +LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | ConsolidatedFunctionResult]] """Legacy function handler that only receives arguments. Args: @@ -105,7 +109,7 @@ class FlowResult(TypedDict, total=False): """ FlowFunctionHandler = Callable[ - [FlowArgs, "FlowManager"], Awaitable[FlowResult | UnifiedFunctionResult] + [FlowArgs, "FlowManager"], Awaitable[FlowResult | ConsolidatedFunctionResult] ] """Modern function handler that receives both arguments and flow_manager. @@ -132,13 +136,13 @@ class DirectFunction(Protocol): **kwargs: Additional keyword arguments Returns: - UnifiedFunctionResult: Result of the function execution, which can include both a FlowResult - and the next node to transition to. + ConsolidatedFunctionResult: Result of the function execution, which can include both a + FlowResult and the next node to transition to. """ def __call__( self, flow_manager: "FlowManager", **kwargs: Any - ) -> Awaitable[UnifiedFunctionResult]: ... + ) -> Awaitable[ConsolidatedFunctionResult]: ... LegacyActionHandler = Callable[[Dict[str, Any]], Awaitable[None]] @@ -285,7 +289,7 @@ def validate_function(function: Callable) -> None: async def invoke( self, args: Mapping[str, Any], flow_manager: "FlowManager" - ) -> UnifiedFunctionResult: + ) -> ConsolidatedFunctionResult: # print(f"[pk] Invoking function {self.name} with args: {args}") return await self.function(flow_manager=flow_manager, **args) From 04965ee9b6b9dcfccdac7fdb6d7007d5717c94fd Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 12:47:23 -0400 Subject: [PATCH 33/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20remove=20comment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_manager.py b/tests/test_manager.py index 4879da5..84e167d 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -509,7 +509,6 @@ async def handler_no_args(): result = await flow_manager._call_handler(handler_no_args, {}) self.assertEqual(result["status"], "success") - # TODO: test async def test_transition_func_error_handling(self): """Test error handling in transition functions.""" flow_manager = FlowManager( From 1f69ce77af62511cb570dc7e958fa452328b9ec2 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 12:52:10 -0400 Subject: [PATCH 34/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20some=20minor=20cleanup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/__init__.py | 2 +- src/pipecat_flows/manager.py | 2 +- src/pipecat_flows/types.py | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index dee7bb9..e2b7295 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -64,6 +64,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): ) from .manager import FlowManager from .types import ( + ConsolidatedFunctionResult, ContextStrategy, ContextStrategyConfig, DirectFunction, @@ -75,7 +76,6 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): LegacyFunctionHandler, NamedNode, NodeConfig, - ConsolidatedFunctionResult, ) __all__ = [ diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index feefd04..4584d9d 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -50,6 +50,7 @@ ) from .types import ( ActionConfig, + ConsolidatedFunctionResult, ContextStrategy, ContextStrategyConfig, FlowArgs, @@ -60,7 +61,6 @@ FunctionHandler, NamedNode, NodeConfig, - ConsolidatedFunctionResult, ) if TYPE_CHECKING: diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 2a168eb..9284a34 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -136,7 +136,7 @@ class DirectFunction(Protocol): **kwargs: Additional keyword arguments Returns: - ConsolidatedFunctionResult: Result of the function execution, which can include both a + ConsolidatedFunctionResult: Result of the function execution, which can include both a FlowResult and the next node to transition to. """ @@ -290,7 +290,6 @@ def validate_function(function: Callable) -> None: async def invoke( self, args: Mapping[str, Any], flow_manager: "FlowManager" ) -> ConsolidatedFunctionResult: - # print(f"[pk] Invoking function {self.name} with args: {args}") return await self.function(flow_manager=flow_manager, **args) def to_function_schema(self) -> FunctionSchema: From 6361de5767305ad13b0a8c8e14dd2a14f4073a50 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 15:03:04 -0400 Subject: [PATCH 35/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20changelog?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 92 +++++++++++++++++++ ...restaurant_reservation_direct_functions.py | 6 +- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 411b2ea..03d6ae5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,98 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added support for providing "consolidated" functions, which are responsible + for both doing some work as well as specifying the next node to transition + to. When using consolidated functions, you don't specify `transition_to` or + `transition_callback`. + + Usage: + + ```python + # "Consolidated" function + async def do_something(args: FlowArgs) -> tuple[FlowResult, NamedNode]: + foo = args["foo"] + bar = args.get("bar", "") + + # Do some work (optional; this function may be a transition-only function) + result = await process(foo, bar) + + # Specify next node (optional; this function may be a work-only function) + # Here, you could use: + # - A NodeConfig by itself + # - A tuple of (name, NodeConfig), as shown below (name is helpful for debug logging) + # - A string identifying a node in a static flow + next_node = ("another_node", create_another_node()) + + return result, next_node + + def create_a_node() -> NodeConfig: + return NodeConfig( + task_messages=[ + # ... + ], + functions=[FlowsFunctionSchema( + name="do_something", + description="Do something interesting.", + handler=do_something, + properties={ + "foo": { + "type": "integer", + "description": "The foo to do something interesting with." + }, + "bar": { + "type": "string", + "description": "The bar to do something interesting with." + } + }, + required=["foo"], + )], + ) + ``` + +- Added support for providing "direct" functions, which don't need an + accompanying `FlowsFunctionSchema` or function definition dict. Instead, + metadata (i.e. `name`, `description`, `properties`, and `required`) are + automatically extracted from a combination of the function signature and + docstring. + + Usage: + + ```python + # "Direct" function + # `flow_manager` must be the first parameter + async def do_something(flow_manager: FlowManager, foo: int, bar: str = "") -> tuple[FlowResult, NamedNode]: + """ + Do something interesting. + + Args: + foo (int): The foo to do something interesting with. + bar (string): The bar to do something interesting with. + """ + + # Do some work (optional; this function may be a transition-only function) + result = await process(foo, bar) + + # Specify next node (optional; this function may be a work-only function) + # Here, you could use: + # - A NodeConfig by itself + # - A tuple of (name, NodeConfig), as shown below (name is helpful for debug logging) + # - A string identifying a node in a static flow + next_node = ("another_node", create_another_node()) + + return result, next_node + + def create_a_node() -> NodeConfig: + return NodeConfig( + task_messages=[ + # ... + ], + functions=[do_something] + ) + ``` + ### Changed - `functions` are now optional in the `NodeConfig`. Additionally, for AWS diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 43e7529..381ca13 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -94,7 +94,7 @@ async def collect_party_size( result = PartySizeResult(size=size, status="success") # Next node: time selection - # NOTE: name is optional, but useful for debug logging; you could pass a NodeConfig here directly + # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here directly next_node = "get_time", create_time_selection_node() return result, next_node @@ -119,7 +119,7 @@ async def check_availability( ) # Next node: confirmation or no availability - # NOTE: name is optional, but useful for debug logging; you could pass a NodeConfig here directly + # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here directly if is_available: next_node = "confirm", create_confirmation_node() else: @@ -130,7 +130,7 @@ async def check_availability( async def end_conversation(flow_manager: FlowManager) -> tuple[None, NamedNode]: """End the conversation.""" - # NOTE: name is optional, but useful for debug logging; you could pass a NodeConfig here directly + # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here directly return None, ("end", create_end_node()) From bd437429728e6f5358a80ad4dfa83b55cda6b5e9 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Thu, 29 May 2025 15:13:02 -0400 Subject: [PATCH 36/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20test=20tweak=20to=20better=20exer?= =?UTF-8?q?cise=20non-required=20parameters=20(specifically,=20to=20ensure?= =?UTF-8?q?=20that=20`Optional`=20and=20not=20required=20are=20not=20the?= =?UTF-8?q?=20same)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_flows_direct_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_flows_direct_functions.py b/tests/test_flows_direct_functions.py index 1d4e508..a22c0e3 100644 --- a/tests/test_flows_direct_functions.py +++ b/tests/test_flows_direct_functions.py @@ -162,8 +162,8 @@ async def my_function_simple_params( async def my_function_complex_params( flow_manager: FlowManager, - address_lines: list[str], - nickname: str | int | None = "Bud", + address_lines: Optional[list[str]], + nickname: str | int = "Bud", extra: Optional[dict[str, str]] = None, ): return {"status": "success"}, None From 33abc74b30b46283536e409030eb566f846cf45f Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 11:02:41 -0400 Subject: [PATCH 37/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20add=20optional=20`name`=20field?= =?UTF-8?q?=20to=20`NodeConfig`=20so=20that=20we=20don't=20have=20to=20hav?= =?UTF-8?q?e=20the=20awkward=20`NamedNode`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 18 +++++---------- ...restaurant_reservation_direct_functions.py | 23 ++++++++++--------- src/pipecat_flows/__init__.py | 2 -- src/pipecat_flows/manager.py | 17 +++++++------- src/pipecat_flows/types.py | 15 +++++------- 5 files changed, 32 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03d6ae5..5f70732 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ```python # "Consolidated" function - async def do_something(args: FlowArgs) -> tuple[FlowResult, NamedNode]: + async def do_something(args: FlowArgs) -> tuple[FlowResult, NodeConfig]: foo = args["foo"] bar = args.get("bar", "") @@ -26,11 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 result = await process(foo, bar) # Specify next node (optional; this function may be a work-only function) - # Here, you could use: - # - A NodeConfig by itself - # - A tuple of (name, NodeConfig), as shown below (name is helpful for debug logging) - # - A string identifying a node in a static flow - next_node = ("another_node", create_another_node()) + # This is either a NodeConfig (for dynamic flows) or a node name (for static flows) + next_node = create_another_node() return result, next_node @@ -69,7 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ```python # "Direct" function # `flow_manager` must be the first parameter - async def do_something(flow_manager: FlowManager, foo: int, bar: str = "") -> tuple[FlowResult, NamedNode]: + async def do_something(flow_manager: FlowManager, foo: int, bar: str = "") -> tuple[FlowResult, NodeConfig]: """ Do something interesting. @@ -82,11 +79,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 result = await process(foo, bar) # Specify next node (optional; this function may be a work-only function) - # Here, you could use: - # - A NodeConfig by itself - # - A tuple of (name, NodeConfig), as shown below (name is helpful for debug logging) - # - A string identifying a node in a static flow - next_node = ("another_node", create_another_node()) + # This is either a NodeConfig (for dynamic flows) or a node name (for static flows) + next_node = create_another_node() return result, next_node diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 381ca13..6b55813 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -23,7 +23,7 @@ from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.services.daily import DailyParams, DailyTransport -from pipecat_flows import FlowManager, FlowResult, NamedNode, NodeConfig +from pipecat_flows import FlowManager, FlowResult, NodeConfig sys.path.append(str(Path(__file__).parent.parent)) import argparse @@ -83,7 +83,7 @@ class TimeResult(FlowResult): # Function handlers async def collect_party_size( flow_manager: FlowManager, size: int -) -> tuple[PartySizeResult, NamedNode]: +) -> tuple[PartySizeResult, NodeConfig]: """ Record the number of people in the party. @@ -94,15 +94,14 @@ async def collect_party_size( result = PartySizeResult(size=size, status="success") # Next node: time selection - # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here directly - next_node = "get_time", create_time_selection_node() + next_node = create_time_selection_node() return result, next_node async def check_availability( flow_manager: FlowManager, time: str, party_size: int -) -> tuple[TimeResult, NamedNode]: +) -> tuple[TimeResult, NodeConfig]: """ Check availability for requested time. @@ -119,19 +118,17 @@ async def check_availability( ) # Next node: confirmation or no availability - # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here directly if is_available: - next_node = "confirm", create_confirmation_node() + next_node = create_confirmation_node() else: - next_node = "no_availability", create_no_availability_node(alternative_times) + next_node = create_no_availability_node(alternative_times) return result, next_node -async def end_conversation(flow_manager: FlowManager) -> tuple[None, NamedNode]: +async def end_conversation(flow_manager: FlowManager) -> tuple[None, NodeConfig]: """End the conversation.""" - # NOTE: name is optional, but useful for debug logging; you could use a NodeConfig here directly - return None, ("end", create_end_node()) + return None, create_end_node() # Node configurations @@ -159,6 +156,7 @@ def create_time_selection_node() -> NodeConfig: """Create node for time selection and availability check.""" logger.debug("Creating time selection node") return { + "name": "get_time", "task_messages": [ { "role": "system", @@ -172,6 +170,7 @@ def create_time_selection_node() -> NodeConfig: def create_confirmation_node() -> NodeConfig: """Create confirmation node for successful reservations.""" return { + "name": "confirm", "task_messages": [ { "role": "system", @@ -186,6 +185,7 @@ def create_no_availability_node(alternative_times: list[str]) -> NodeConfig: """Create node for handling no availability.""" times_list = ", ".join(alternative_times) return { + "name": "no_availability", "task_messages": [ { "role": "system", @@ -203,6 +203,7 @@ def create_no_availability_node(alternative_times: list[str]) -> NodeConfig: def create_end_node() -> NodeConfig: """Create the final node.""" return { + "name": "end", "task_messages": [ { "role": "system", diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index e2b7295..cfb02c5 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -74,7 +74,6 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): FlowResult, FlowsFunctionSchema, LegacyFunctionHandler, - NamedNode, NodeConfig, ) @@ -92,7 +91,6 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): "FlowsFunctionSchema", "LegacyFunctionHandler", "DirectFunction", - "NamedNode", "NodeConfig", # Exceptions "FlowError", diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 4584d9d..c73b662 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -59,7 +59,6 @@ FlowsDirectFunction, FlowsFunctionSchema, FunctionHandler, - NamedNode, NodeConfig, ) @@ -310,7 +309,7 @@ def decrease_pending_function_calls() -> None: ) async def on_context_updated_edge( - next_node: Optional[NodeConfig | NamedNode], + next_node: Optional[NodeConfig | str], args: Optional[Dict[str, Any]], result: Optional[Any], result_callback: Callable, @@ -334,13 +333,13 @@ async def on_context_updated_edge( if self._pending_function_calls == 0: if next_node: # Function-returned next node (as opposed to next node specified by transition_*) if isinstance(next_node, str): # Static flow - next_node_name, next_node = next_node, self.nodes[next_node] - elif isinstance(next_node, tuple): # Dynamic flow with named node - next_node_name, next_node = next_node - else: # Dynamic flow with anonymous node - next_node_name, next_node = str(uuid.uuid4()), next_node - logger.debug(f"Transition to function-returned node: {next_node_name}") - await self.set_node(next_node_name, next_node) + node_name = next_node + node = self.nodes[next_node] + else: # Dynamic flow + node_name = next_node.get("name", str(uuid.uuid4())) + node = next_node + logger.debug(f"Transition to function-returned node: {node_name}") + await self.set_node(node_name, node) elif transition_to: # Static flow logger.debug(f"Static transition to: {transition_to}") await self.set_node(transition_to, self.nodes[transition_to]) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 9284a34..4db4a09 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -84,18 +84,13 @@ class FlowResult(TypedDict, total=False): } """ -NamedNode = str | tuple[str, "NodeConfig"] -""" -Type alias for a named node, which can either be: -- A string representing the node name (for static flows) -- A tuple containing the node name and a NodeConfig instance (for dynamic flows) -""" - -ConsolidatedFunctionResult = Tuple[Optional[FlowResult], Optional[Union["NodeConfig", NamedNode]]] +ConsolidatedFunctionResult = Tuple[Optional[FlowResult], Optional[Union["NodeConfig", str]]] """ Return type for "consolidated" functions that do either or both of: - doing some work -- specifying the next node to transition to after the work is done +- specifying the next node to transition to after the work is done, specified as either: + - a NodeConfig (for dynamic flows) + - a node name (for static flows) """ LegacyFunctionHandler = Callable[[FlowArgs], Awaitable[FlowResult | ConsolidatedFunctionResult]] @@ -455,6 +450,7 @@ class NodeConfig(NodeConfigRequired, total=False): task_messages: List of message dicts defining the current node's objectives Optional fields: + name: Name of the node, useful for debug logging role_messages: List of message dicts defining the bot's role/personality functions: List of function definitions in provider-specific format, FunctionSchema, or FlowsFunctionSchema; or a "direct function" whose definition is automatically extracted @@ -484,6 +480,7 @@ class NodeConfig(NodeConfigRequired, total=False): } """ + name: str role_messages: List[Dict[str, Any]] functions: List[Union[Dict[str, Any], FlowsFunctionSchema, DirectFunction]] pre_actions: List[ActionConfig] From 8f2e1edbb09a84ad824c6256d207f26b044ad04f Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 14:18:21 -0400 Subject: [PATCH 38/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20deprecate=20`transition=5Fto`=20a?= =?UTF-8?q?nd=20`transition=5Fcallback`=20in=20favor=20of=20"consolidated"?= =?UTF-8?q?=20`handler`=20functions=20that=20return=20a=20tuple=20(result,?= =?UTF-8?q?=20next=5Fnode)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/manager.py | 18 ++++++++++++++++++ src/pipecat_flows/types.py | 9 +++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index c73b662..ade268f 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -27,6 +27,7 @@ import inspect import sys import uuid +import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, cast from loguru import logger @@ -142,6 +143,8 @@ def __init__( self.current_functions: Set[str] = set() # Track registered functions self.current_node: Optional[str] = None + self._showed_deprecation_warning_for_transition_fields = False + def _validate_transition_callback(self, name: str, callback: Any) -> None: """Validate a transition callback. @@ -838,3 +841,18 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: logger.warning( f"Function '{name}' in node '{node_id}' has neither handler, transition_to, nor transition_callback" ) + + # Warn about usage of deprecated transition_to and transition_callback + if ( + has_transition_to + or has_transition_callback + and not self._showed_deprecation_warning_for_transition_fields + ): + self._showed_deprecation_warning_for_transition_fields = True + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + '`transition_to` and `transition_callback` are deprecated and will be removed in a future version. Use a "consolidated" `handler` that returns a tuple (result, next_node) instead', + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 4db4a09..7a1ccae 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -231,8 +231,13 @@ class FlowsFunctionSchema: properties: Dictionary defining properties types and descriptions required: List of required parameters handler: Function handler to process the function call - transition_to: Target node to transition to after function execution - transition_callback: Callback function for dynamic transitions + transition_to: Target node to transition to after function execution (deprecated) + transition_callback: Callback function for dynamic transitions (deprecated) + + Deprecated: + 0.0.18: `transition_to` and `transition_callback` are deprecated and will be removed in a + future version. Use a "consolidated" `handler` that returns a tuple (result, next_node) + instead. """ name: str From 538cf83d0788e8d87264105cff1dcbfae271196e Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 14:51:44 -0400 Subject: [PATCH 39/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20update=20CHANGELOG?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 11 +++++++++++ src/pipecat_flows/manager.py | 2 +- src/pipecat_flows/types.py | 3 ++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f70732..50a8e04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Addded a new optional `name` field to `NodeConfig`. When using dynamic flows alongside + "consolidated" functions that return a tuple (result, next node), giving the next node a `name` is + helpful for debug logging. If you don't specify a `name`, an automatically-generated UUID is used. + - Added support for providing "consolidated" functions, which are responsible for both doing some work as well as specifying the next node to transition to. When using consolidated functions, you don't specify `transition_to` or @@ -93,6 +97,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ) ``` +### Deprecated + +- Deprecated `transition_to` and `transition_callback` in favor of "consolidated" `handler`s that + return a tuple (result, next node). Alternatively, you could use "direct" functions and avoid + using `FlowsFunctionSchema`s or function definition dicts entirely. See the "Added" section above + for more details. + ### Changed - `functions` are now optional in the `NodeConfig`. Additionally, for AWS diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index ade268f..71667e5 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -852,7 +852,7 @@ def _validate_node_config(self, node_id: str, config: NodeConfig) -> None: with warnings.catch_warnings(): warnings.simplefilter("always") warnings.warn( - '`transition_to` and `transition_callback` are deprecated and will be removed in a future version. Use a "consolidated" `handler` that returns a tuple (result, next_node) instead', + '`transition_to` and `transition_callback` are deprecated and will be removed in a future version. Use a "consolidated" `handler` that returns a tuple (result, next_node) instead.', DeprecationWarning, stacklevel=2, ) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 7a1ccae..d7ee133 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -455,7 +455,8 @@ class NodeConfig(NodeConfigRequired, total=False): task_messages: List of message dicts defining the current node's objectives Optional fields: - name: Name of the node, useful for debug logging + name: Name of the node, useful for debug logging when returning a next node from a + "consolidated" function role_messages: List of message dicts defining the bot's role/personality functions: List of function definitions in provider-specific format, FunctionSchema, or FlowsFunctionSchema; or a "direct function" whose definition is automatically extracted From 4907c5de56e226b676fc9b6c33ce2bbcdfb139f6 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 15:28:58 -0400 Subject: [PATCH 40/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20`initialize()`=20should=20take=20?= =?UTF-8?q?an=20initial=20node,=20obviating=20the=20need=20for=20the=20use?= =?UTF-8?q?r=20to=20call=20`set=5Fnode`=20directly;=20this=20change=20feel?= =?UTF-8?q?s=20important=20to=20do=20now=20that=20we've=20added=20the=20`n?= =?UTF-8?q?ame`=20field=20to=20`NodeConfig`=20so=20that=20users=20don't=20?= =?UTF-8?q?end=20up=20confusingly=20specifying=20(potentially=20different)?= =?UTF-8?q?=20names=20in=20their=20`NodeConfig`s=20as=20in=20their=20`set?= =?UTF-8?q?=5Fnode`=20calls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...restaurant_reservation_direct_functions.py | 5 ++-- src/pipecat_flows/manager.py | 27 +++++++++++++++---- src/pipecat_flows/types.py | 6 +++++ 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/examples/dynamic/restaurant_reservation_direct_functions.py b/examples/dynamic/restaurant_reservation_direct_functions.py index 6b55813..1d7c71d 100644 --- a/examples/dynamic/restaurant_reservation_direct_functions.py +++ b/examples/dynamic/restaurant_reservation_direct_functions.py @@ -135,6 +135,7 @@ async def end_conversation(flow_manager: FlowManager) -> tuple[None, NodeConfig] def create_initial_node(wait_for_user: bool) -> NodeConfig: """Create initial node for party size collection.""" return { + "name": "initial", "role_messages": [ { "role": "system", @@ -266,9 +267,7 @@ async def main(wait_for_user: bool): async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) logger.debug("Initializing flow manager") - await flow_manager.initialize() - logger.debug("Setting initial node") - await flow_manager.set_node("initial", create_initial_node(wait_for_user)) + await flow_manager.initialize(create_initial_node(wait_for_user)) runner = PipelineRunner() await runner.run(task) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 71667e5..1288042 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -61,6 +61,7 @@ FlowsFunctionSchema, FunctionHandler, NodeConfig, + get_or_generate_node_name, ) if TYPE_CHECKING: @@ -160,7 +161,7 @@ def _validate_transition_callback(self, name: str, callback: Any) -> None: if not inspect.iscoroutinefunction(callback): raise ValueError(f"Transition callback for {name} must be async") - async def initialize(self) -> None: + async def initialize(self, initial_node: Optional[NodeConfig] = None) -> None: """Initialize the flow manager.""" if self.initialized: logger.warning(f"{self.__class__.__name__} already initialized") @@ -170,10 +171,26 @@ async def initialize(self) -> None: self.initialized = True logger.debug(f"Initialized {self.__class__.__name__}") - # If in static mode, set initial node + # Set initial node + node_name = None + node = None if self.initial_node: - logger.debug(f"Setting initial node: {self.initial_node}") - await self.set_node(self.initial_node, self.nodes[self.initial_node]) + # Static flow: self.initial_node is expected to be there + node_name = self.initial_node + node = self.nodes[self.initial_node] + if not node: + raise ValueError( + f"Initial node '{self.initial_node}' not found in static flow configuration" + ) + else: + # Dynamic flow: initial_node argument may have been provided (otherwise initial node + # will be set later via set_node()) + if initial_node: + node_name = get_or_generate_node_name(initial_node) + node = initial_node + if node_name: + logger.debug(f"Setting initial node: {node_name}") + await self.set_node(node_name, node) except Exception as e: self.initialized = False @@ -339,7 +356,7 @@ async def on_context_updated_edge( node_name = next_node node = self.nodes[next_node] else: # Dynamic flow - node_name = next_node.get("name", str(uuid.uuid4())) + node_name = get_or_generate_node_name(next_node) node = next_node logger.debug(f"Transition to function-returned node: {node_name}") await self.set_node(node_name, node) diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index d7ee133..62b254e 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -18,6 +18,7 @@ import inspect import types +import uuid from dataclasses import dataclass from enum import Enum from typing import ( @@ -495,6 +496,11 @@ class NodeConfig(NodeConfigRequired, total=False): respond_immediately: bool +def get_or_generate_node_name(node_config: NodeConfig) -> str: + """Get the node name from the given configuration, defaulting to a UUID if not set.""" + return node_config.get("name", str(uuid.uuid4())) + + class FlowConfig(TypedDict): """Configuration for the entire conversation flow. From 758602ae0995e0e8d87dd3e81e9c88b2b09ce587 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 15:49:34 -0400 Subject: [PATCH 41/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20deprecate=20`set=5Fnode()`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 5 +++ src/pipecat_flows/manager.py | 21 +++++++++-- tests/test_context_strategies.py | 28 +++++++-------- tests/test_manager.py | 60 ++++++++++++++++---------------- 4 files changed, 67 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 50a8e04..135cff3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,6 +104,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 using `FlowsFunctionSchema`s or function definition dicts entirely. See the "Added" section above for more details. +- Deprecated `set_node()` in favor of doing the following for dynamic flows: + - Use "consolidated" or "direct" functions that return a tuple (result, next node); provide a + `name` in your next node's config for debug logging purposes + - Pass your initial node to `FlowManager.initialize()` + ### Changed - `functions` are now optional in the `NodeConfig`. Additionally, for AWS diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 1288042..a4255ac 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -145,6 +145,7 @@ def __init__( self.current_node: Optional[str] = None self._showed_deprecation_warning_for_transition_fields = False + self._showed_deprecation_warning_for_set_node = False def _validate_transition_callback(self, name: str, callback: Any) -> None: """Validate a transition callback. @@ -190,7 +191,7 @@ async def initialize(self, initial_node: Optional[NodeConfig] = None) -> None: node = initial_node if node_name: logger.debug(f"Setting initial node: {node_name}") - await self.set_node(node_name, node) + await self._set_node(node_name, node) except Exception as e: self.initialized = False @@ -359,10 +360,10 @@ async def on_context_updated_edge( node_name = get_or_generate_node_name(next_node) node = next_node logger.debug(f"Transition to function-returned node: {node_name}") - await self.set_node(node_name, node) + await self._set_node(node_name, node) elif transition_to: # Static flow logger.debug(f"Static transition to: {transition_to}") - await self.set_node(transition_to, self.nodes[transition_to]) + await self._set_node(transition_to, self.nodes[transition_to]) elif transition_callback: # Dynamic flow logger.debug(f"Dynamic transition for: {name}") # Check callback signature @@ -544,6 +545,20 @@ async def _register_function( raise FlowError(f"Function registration failed: {str(e)}") from e async def set_node(self, node_id: str, node_config: NodeConfig) -> None: + if not self._showed_deprecation_warning_for_set_node: + self._showed_deprecation_warning_for_set_node = True + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + """`set_node()` is deprecated and will be removed in a future version. Instead, do the following for dynamic flows: +- Use "consolidated" or "direct" functions that return a tuple (result, next_node); provide a `name` in your next_node's config for debug logging purposes +- Pass your initial node to `FlowManager.initialize()`""", + DeprecationWarning, + stacklevel=2, + ) + await self._set_node(node_id, node_config) + + async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. Handles the complete node transition process in the following order: diff --git a/tests/test_context_strategies.py b/tests/test_context_strategies.py index 2c27ddb..715213e 100644 --- a/tests/test_context_strategies.py +++ b/tests/test_context_strategies.py @@ -96,7 +96,7 @@ async def test_default_strategy(self): await flow_manager.initialize() # First node should use UpdateFrame regardless of strategy - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) first_call = self.mock_task.queue_frames.call_args_list[0] first_frames = first_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in first_frames)) @@ -105,7 +105,7 @@ async def test_default_strategy(self): self.mock_task.queue_frames.reset_mock() # Subsequent node should use AppendFrame with default strategy - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) second_call = self.mock_task.queue_frames.call_args_list[0] second_frames = second_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames)) @@ -121,11 +121,11 @@ async def test_reset_strategy(self): await flow_manager.initialize() # Set initial node - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() # Second node should use UpdateFrame with RESET strategy - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) second_call = self.mock_task.queue_frames.call_args_list[0] second_frames = second_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in second_frames)) @@ -150,10 +150,10 @@ async def test_reset_with_summary_success(self): await flow_manager.initialize() # Set nodes and verify summary inclusion - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) # Verify summary was included in context update second_call = self.mock_task.queue_frames.call_args_list[0] @@ -180,10 +180,10 @@ async def test_reset_with_summary_timeout(self): ) # Set nodes and verify fallback to RESET - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) # Verify UpdateFrame was used (RESET behavior) second_call = self.mock_task.queue_frames.call_args_list[0] @@ -238,10 +238,10 @@ async def test_node_level_strategy_override(self): } # Set nodes and verify strategy override - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", node_with_strategy) + await flow_manager._set_node("second", node_with_strategy) # Verify UpdateFrame was used (RESET behavior) despite global APPEND second_call = self.mock_task.queue_frames.call_args_list[0] @@ -267,8 +267,8 @@ async def test_summary_generation_content(self): await flow_manager.initialize() # Set nodes to trigger summary generation - await flow_manager.set_node("first", self.sample_node) - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("first", self.sample_node) + await flow_manager._set_node("second", self.sample_node) # Verify summary generation call create_call = self.mock_llm._client.chat.completions.create.call_args @@ -299,7 +299,7 @@ async def test_context_structure_after_summary(self): await flow_manager.initialize() # Set nodes to trigger summary generation - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() # Node with new task messages @@ -307,7 +307,7 @@ async def test_context_structure_after_summary(self): "task_messages": [{"role": "system", "content": "New task."}], "functions": [], } - await flow_manager.set_node("second", new_node) + await flow_manager._set_node("second", new_node) # Verify context structure update_call = self.mock_task.queue_frames.call_args_list[0] diff --git a/tests/test_manager.py b/tests/test_manager.py index 84e167d..912adf5 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -134,7 +134,7 @@ async def test_dynamic_flow_initialization(self): # Initialize and set node await flow_manager.initialize() - await flow_manager.set_node("test", test_node) + await flow_manager._set_node("test", test_node) self.assertFalse( mock_transition_handler.called @@ -160,7 +160,7 @@ async def test_static_flow_transitions(self): # In static flows, transitions happen through set_node with a # predefined node configuration from the flow_config - await flow_manager.set_node("next_node", flow_manager.nodes["next_node"]) + await flow_manager._set_node("next_node", flow_manager.nodes["next_node"]) # Verify node transition occurred self.assertEqual(flow_manager.current_node, "next_node") @@ -231,7 +231,7 @@ async def test_handler(args: FlowArgs) -> FlowResult: } # Test old style callback - await flow_manager.set_node("old_style", old_style_node) + await flow_manager._set_node("old_style", old_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Store the context_updated callback @@ -279,7 +279,7 @@ async def result_callback(result, properties=None): } # Test new style callback - await flow_manager.set_node("new_style", new_style_node) + await flow_manager._set_node("new_style", new_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Reset context_updated callback @@ -316,12 +316,12 @@ async def test_node_validation(self): # Test missing task_messages invalid_config = {"functions": []} with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", invalid_config) + await flow_manager._set_node("test", invalid_config) self.assertIn("missing required 'task_messages' field", str(context.exception)) # Test valid config valid_config = {"task_messages": []} - await flow_manager.set_node("test", valid_config) + await flow_manager._set_node("test", valid_config) self.assertEqual(flow_manager.current_node, "test") self.assertEqual(flow_manager.current_functions, set()) @@ -339,7 +339,7 @@ async def test_function_registration(self): self.mock_llm.register_function.reset_mock() # Set node with function - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) # Verify function was registered self.mock_llm.register_function.assert_called_once() @@ -370,7 +370,7 @@ async def test_action_execution(self): self.mock_tts.say.reset_mock() # Set node with actions - await flow_manager.set_node("test", node_with_actions) + await flow_manager._set_node("test", node_with_actions) # Verify TTS was called for both actions self.assertEqual(self.mock_tts.say.call_count, 2) @@ -393,7 +393,7 @@ async def test_error_handling(self): # Test setting node before initialization with self.assertRaises(FlowTransitionError): - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) # Initialize normally await flow_manager.initialize() @@ -402,7 +402,7 @@ async def test_error_handling(self): # Test node setting error self.mock_task.queue_frames.side_effect = Exception("Queue error") with self.assertRaises(FlowError): - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) # Verify flow manager remains initialized despite error self.assertTrue(flow_manager.initialized) @@ -424,7 +424,7 @@ async def test_state_management(self): self.mock_task.queue_frames.reset_mock() # Verify state persists across node transitions - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) self.assertEqual(flow_manager.state["test_key"], test_value) async def test_multiple_function_registration(self): @@ -452,7 +452,7 @@ async def test_multiple_function_registration(self): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Verify all functions were registered self.assertEqual(self.mock_llm.register_function.call_count, 3) @@ -562,7 +562,7 @@ async def test_node_validation_edge_cases(self): "functions": [{"type": "function"}], # Missing name } with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", invalid_config) + await flow_manager._set_node("test", invalid_config) self.assertIn("invalid format", str(context.exception)) # Test node function without handler or transition_to @@ -588,7 +588,7 @@ def capture_warning(msg, *args, **kwargs): warning_message = msg with patch("loguru.logger.warning", side_effect=capture_warning): - await flow_manager.set_node("test", invalid_config) + await flow_manager._set_node("test", invalid_config) self.assertIsNotNone(warning_message) self.assertIn( "Function 'test_func' in node 'test' has neither handler, transition_to, nor transition_callback", @@ -627,7 +627,7 @@ async def failing_handler(args, flow_manager): } # Set up node and get registered function - await flow_manager.set_node("test", test_node) + await flow_manager._set_node("test", test_node) transition_func = flow_manager.llm.register_function.call_args[0][1] # Track the result and context_updated callback @@ -702,7 +702,7 @@ async def test_action_execution_error_handling(self): # Should raise FlowError due to invalid actions with self.assertRaises(FlowError): - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Verify error handling for pre and post actions separately with self.assertRaises(FlowError): @@ -766,7 +766,7 @@ async def test_handler(args): } # Set node and verify function registration - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Verify both functions were registered self.assertIn("test1", flow_manager.current_functions) @@ -806,7 +806,7 @@ async def test_handler_main(args): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) self.assertIn("test_function", flow_manager.current_functions) finally: @@ -838,7 +838,7 @@ async def test_function_token_handling_not_found(self): } with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) self.assertIn("Function 'nonexistent_handler' not found", str(context.exception)) @@ -879,7 +879,7 @@ async def test_handler(args): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Get the registered function and test it name, func = self.mock_llm.register_function.call_args[0] @@ -928,7 +928,7 @@ async def test_role_message_inheritance(self): } # Set first node and verify UpdateFrame - await flow_manager.set_node("first", first_node) + await flow_manager._set_node("first", first_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] @@ -940,7 +940,7 @@ async def test_role_message_inheritance(self): # Reset mock and set second node self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", second_node) + await flow_manager._set_node("second", second_node) # Verify AppendFrame for second node first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call @@ -966,7 +966,7 @@ async def test_frame_type_selection(self): } # First node should use UpdateFrame - await flow_manager.set_node("first", test_node) + await flow_manager._set_node("first", test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] self.assertTrue( @@ -982,7 +982,7 @@ async def test_frame_type_selection(self): self.mock_task.queue_frames.reset_mock() # Second node should use AppendFrame - await flow_manager.set_node("second", test_node) + await flow_manager._set_node("second", test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call second_frames = first_call[0][0] self.assertTrue( @@ -1033,7 +1033,7 @@ async def test_handler(args): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Get the registered functions node_func = None @@ -1105,7 +1105,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager.set_node( + await flow_manager._set_node( "initial", { "task_messages": [{"role": "system", "content": "Test"}], @@ -1131,7 +1131,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager.set_node("next", next_node) + await flow_manager._set_node("next", next_node) # Should see context update and completion trigger again self.assertTrue(self.mock_task.queue_frames.called) @@ -1168,7 +1168,7 @@ async def test_transition_configuration_exclusivity(self): # Should raise error when trying to use both with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", test_node) + await flow_manager._set_node("test", test_node) self.assertIn( "Cannot specify both transition_to and transition_callback", str(context.exception) ) @@ -1236,7 +1236,7 @@ async def test_node_without_functions(self): } # Set node and verify it works without error - await flow_manager.set_node("no_functions", node_config) + await flow_manager._set_node("no_functions", node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) @@ -1265,7 +1265,7 @@ async def test_node_with_empty_functions(self): } # Set node and verify it works without error - await flow_manager.set_node("empty_functions", node_config) + await flow_manager._set_node("empty_functions", node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) From e34c3866675a3040add4bd4b10f8b06057bf6e39 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 17:13:47 -0400 Subject: [PATCH 42/47] =?UTF-8?q?Revert=20"[WIP]=20More=20progress=20on=20?= =?UTF-8?q?=E2=80=9Cdirect=20functions=E2=80=9D:=20deprecate=20`set=5Fnode?= =?UTF-8?q?()`"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 284e464de3ad16aed5fda2678844c87f415a994b. It turns out that `set_node()` is useful outside of the function lifecycle; it's used, for example, in the warm transfer demo in an `on_participant_joined` handler (when the human agent joins) --- CHANGELOG.md | 5 --- src/pipecat_flows/manager.py | 21 ++--------- tests/test_context_strategies.py | 28 +++++++-------- tests/test_manager.py | 60 ++++++++++++++++---------------- 4 files changed, 47 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 135cff3..50a8e04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,11 +104,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 using `FlowsFunctionSchema`s or function definition dicts entirely. See the "Added" section above for more details. -- Deprecated `set_node()` in favor of doing the following for dynamic flows: - - Use "consolidated" or "direct" functions that return a tuple (result, next node); provide a - `name` in your next node's config for debug logging purposes - - Pass your initial node to `FlowManager.initialize()` - ### Changed - `functions` are now optional in the `NodeConfig`. Additionally, for AWS diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index a4255ac..1288042 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -145,7 +145,6 @@ def __init__( self.current_node: Optional[str] = None self._showed_deprecation_warning_for_transition_fields = False - self._showed_deprecation_warning_for_set_node = False def _validate_transition_callback(self, name: str, callback: Any) -> None: """Validate a transition callback. @@ -191,7 +190,7 @@ async def initialize(self, initial_node: Optional[NodeConfig] = None) -> None: node = initial_node if node_name: logger.debug(f"Setting initial node: {node_name}") - await self._set_node(node_name, node) + await self.set_node(node_name, node) except Exception as e: self.initialized = False @@ -360,10 +359,10 @@ async def on_context_updated_edge( node_name = get_or_generate_node_name(next_node) node = next_node logger.debug(f"Transition to function-returned node: {node_name}") - await self._set_node(node_name, node) + await self.set_node(node_name, node) elif transition_to: # Static flow logger.debug(f"Static transition to: {transition_to}") - await self._set_node(transition_to, self.nodes[transition_to]) + await self.set_node(transition_to, self.nodes[transition_to]) elif transition_callback: # Dynamic flow logger.debug(f"Dynamic transition for: {name}") # Check callback signature @@ -545,20 +544,6 @@ async def _register_function( raise FlowError(f"Function registration failed: {str(e)}") from e async def set_node(self, node_id: str, node_config: NodeConfig) -> None: - if not self._showed_deprecation_warning_for_set_node: - self._showed_deprecation_warning_for_set_node = True - with warnings.catch_warnings(): - warnings.simplefilter("always") - warnings.warn( - """`set_node()` is deprecated and will be removed in a future version. Instead, do the following for dynamic flows: -- Use "consolidated" or "direct" functions that return a tuple (result, next_node); provide a `name` in your next_node's config for debug logging purposes -- Pass your initial node to `FlowManager.initialize()`""", - DeprecationWarning, - stacklevel=2, - ) - await self._set_node(node_id, node_config) - - async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. Handles the complete node transition process in the following order: diff --git a/tests/test_context_strategies.py b/tests/test_context_strategies.py index 715213e..2c27ddb 100644 --- a/tests/test_context_strategies.py +++ b/tests/test_context_strategies.py @@ -96,7 +96,7 @@ async def test_default_strategy(self): await flow_manager.initialize() # First node should use UpdateFrame regardless of strategy - await flow_manager._set_node("first", self.sample_node) + await flow_manager.set_node("first", self.sample_node) first_call = self.mock_task.queue_frames.call_args_list[0] first_frames = first_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in first_frames)) @@ -105,7 +105,7 @@ async def test_default_strategy(self): self.mock_task.queue_frames.reset_mock() # Subsequent node should use AppendFrame with default strategy - await flow_manager._set_node("second", self.sample_node) + await flow_manager.set_node("second", self.sample_node) second_call = self.mock_task.queue_frames.call_args_list[0] second_frames = second_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames)) @@ -121,11 +121,11 @@ async def test_reset_strategy(self): await flow_manager.initialize() # Set initial node - await flow_manager._set_node("first", self.sample_node) + await flow_manager.set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() # Second node should use UpdateFrame with RESET strategy - await flow_manager._set_node("second", self.sample_node) + await flow_manager.set_node("second", self.sample_node) second_call = self.mock_task.queue_frames.call_args_list[0] second_frames = second_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in second_frames)) @@ -150,10 +150,10 @@ async def test_reset_with_summary_success(self): await flow_manager.initialize() # Set nodes and verify summary inclusion - await flow_manager._set_node("first", self.sample_node) + await flow_manager.set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager._set_node("second", self.sample_node) + await flow_manager.set_node("second", self.sample_node) # Verify summary was included in context update second_call = self.mock_task.queue_frames.call_args_list[0] @@ -180,10 +180,10 @@ async def test_reset_with_summary_timeout(self): ) # Set nodes and verify fallback to RESET - await flow_manager._set_node("first", self.sample_node) + await flow_manager.set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager._set_node("second", self.sample_node) + await flow_manager.set_node("second", self.sample_node) # Verify UpdateFrame was used (RESET behavior) second_call = self.mock_task.queue_frames.call_args_list[0] @@ -238,10 +238,10 @@ async def test_node_level_strategy_override(self): } # Set nodes and verify strategy override - await flow_manager._set_node("first", self.sample_node) + await flow_manager.set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager._set_node("second", node_with_strategy) + await flow_manager.set_node("second", node_with_strategy) # Verify UpdateFrame was used (RESET behavior) despite global APPEND second_call = self.mock_task.queue_frames.call_args_list[0] @@ -267,8 +267,8 @@ async def test_summary_generation_content(self): await flow_manager.initialize() # Set nodes to trigger summary generation - await flow_manager._set_node("first", self.sample_node) - await flow_manager._set_node("second", self.sample_node) + await flow_manager.set_node("first", self.sample_node) + await flow_manager.set_node("second", self.sample_node) # Verify summary generation call create_call = self.mock_llm._client.chat.completions.create.call_args @@ -299,7 +299,7 @@ async def test_context_structure_after_summary(self): await flow_manager.initialize() # Set nodes to trigger summary generation - await flow_manager._set_node("first", self.sample_node) + await flow_manager.set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() # Node with new task messages @@ -307,7 +307,7 @@ async def test_context_structure_after_summary(self): "task_messages": [{"role": "system", "content": "New task."}], "functions": [], } - await flow_manager._set_node("second", new_node) + await flow_manager.set_node("second", new_node) # Verify context structure update_call = self.mock_task.queue_frames.call_args_list[0] diff --git a/tests/test_manager.py b/tests/test_manager.py index 912adf5..84e167d 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -134,7 +134,7 @@ async def test_dynamic_flow_initialization(self): # Initialize and set node await flow_manager.initialize() - await flow_manager._set_node("test", test_node) + await flow_manager.set_node("test", test_node) self.assertFalse( mock_transition_handler.called @@ -160,7 +160,7 @@ async def test_static_flow_transitions(self): # In static flows, transitions happen through set_node with a # predefined node configuration from the flow_config - await flow_manager._set_node("next_node", flow_manager.nodes["next_node"]) + await flow_manager.set_node("next_node", flow_manager.nodes["next_node"]) # Verify node transition occurred self.assertEqual(flow_manager.current_node, "next_node") @@ -231,7 +231,7 @@ async def test_handler(args: FlowArgs) -> FlowResult: } # Test old style callback - await flow_manager._set_node("old_style", old_style_node) + await flow_manager.set_node("old_style", old_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Store the context_updated callback @@ -279,7 +279,7 @@ async def result_callback(result, properties=None): } # Test new style callback - await flow_manager._set_node("new_style", new_style_node) + await flow_manager.set_node("new_style", new_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Reset context_updated callback @@ -316,12 +316,12 @@ async def test_node_validation(self): # Test missing task_messages invalid_config = {"functions": []} with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", invalid_config) + await flow_manager.set_node("test", invalid_config) self.assertIn("missing required 'task_messages' field", str(context.exception)) # Test valid config valid_config = {"task_messages": []} - await flow_manager._set_node("test", valid_config) + await flow_manager.set_node("test", valid_config) self.assertEqual(flow_manager.current_node, "test") self.assertEqual(flow_manager.current_functions, set()) @@ -339,7 +339,7 @@ async def test_function_registration(self): self.mock_llm.register_function.reset_mock() # Set node with function - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node("test", self.sample_node) # Verify function was registered self.mock_llm.register_function.assert_called_once() @@ -370,7 +370,7 @@ async def test_action_execution(self): self.mock_tts.say.reset_mock() # Set node with actions - await flow_manager._set_node("test", node_with_actions) + await flow_manager.set_node("test", node_with_actions) # Verify TTS was called for both actions self.assertEqual(self.mock_tts.say.call_count, 2) @@ -393,7 +393,7 @@ async def test_error_handling(self): # Test setting node before initialization with self.assertRaises(FlowTransitionError): - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node("test", self.sample_node) # Initialize normally await flow_manager.initialize() @@ -402,7 +402,7 @@ async def test_error_handling(self): # Test node setting error self.mock_task.queue_frames.side_effect = Exception("Queue error") with self.assertRaises(FlowError): - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node("test", self.sample_node) # Verify flow manager remains initialized despite error self.assertTrue(flow_manager.initialized) @@ -424,7 +424,7 @@ async def test_state_management(self): self.mock_task.queue_frames.reset_mock() # Verify state persists across node transitions - await flow_manager._set_node("test", self.sample_node) + await flow_manager.set_node("test", self.sample_node) self.assertEqual(flow_manager.state["test_key"], test_value) async def test_multiple_function_registration(self): @@ -452,7 +452,7 @@ async def test_multiple_function_registration(self): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) # Verify all functions were registered self.assertEqual(self.mock_llm.register_function.call_count, 3) @@ -562,7 +562,7 @@ async def test_node_validation_edge_cases(self): "functions": [{"type": "function"}], # Missing name } with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", invalid_config) + await flow_manager.set_node("test", invalid_config) self.assertIn("invalid format", str(context.exception)) # Test node function without handler or transition_to @@ -588,7 +588,7 @@ def capture_warning(msg, *args, **kwargs): warning_message = msg with patch("loguru.logger.warning", side_effect=capture_warning): - await flow_manager._set_node("test", invalid_config) + await flow_manager.set_node("test", invalid_config) self.assertIsNotNone(warning_message) self.assertIn( "Function 'test_func' in node 'test' has neither handler, transition_to, nor transition_callback", @@ -627,7 +627,7 @@ async def failing_handler(args, flow_manager): } # Set up node and get registered function - await flow_manager._set_node("test", test_node) + await flow_manager.set_node("test", test_node) transition_func = flow_manager.llm.register_function.call_args[0][1] # Track the result and context_updated callback @@ -702,7 +702,7 @@ async def test_action_execution_error_handling(self): # Should raise FlowError due to invalid actions with self.assertRaises(FlowError): - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) # Verify error handling for pre and post actions separately with self.assertRaises(FlowError): @@ -766,7 +766,7 @@ async def test_handler(args): } # Set node and verify function registration - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) # Verify both functions were registered self.assertIn("test1", flow_manager.current_functions) @@ -806,7 +806,7 @@ async def test_handler_main(args): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) self.assertIn("test_function", flow_manager.current_functions) finally: @@ -838,7 +838,7 @@ async def test_function_token_handling_not_found(self): } with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) self.assertIn("Function 'nonexistent_handler' not found", str(context.exception)) @@ -879,7 +879,7 @@ async def test_handler(args): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) # Get the registered function and test it name, func = self.mock_llm.register_function.call_args[0] @@ -928,7 +928,7 @@ async def test_role_message_inheritance(self): } # Set first node and verify UpdateFrame - await flow_manager._set_node("first", first_node) + await flow_manager.set_node("first", first_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] @@ -940,7 +940,7 @@ async def test_role_message_inheritance(self): # Reset mock and set second node self.mock_task.queue_frames.reset_mock() - await flow_manager._set_node("second", second_node) + await flow_manager.set_node("second", second_node) # Verify AppendFrame for second node first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call @@ -966,7 +966,7 @@ async def test_frame_type_selection(self): } # First node should use UpdateFrame - await flow_manager._set_node("first", test_node) + await flow_manager.set_node("first", test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] self.assertTrue( @@ -982,7 +982,7 @@ async def test_frame_type_selection(self): self.mock_task.queue_frames.reset_mock() # Second node should use AppendFrame - await flow_manager._set_node("second", test_node) + await flow_manager.set_node("second", test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call second_frames = first_call[0][0] self.assertTrue( @@ -1033,7 +1033,7 @@ async def test_handler(args): ], } - await flow_manager._set_node("test", node_config) + await flow_manager.set_node("test", node_config) # Get the registered functions node_func = None @@ -1105,7 +1105,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager._set_node( + await flow_manager.set_node( "initial", { "task_messages": [{"role": "system", "content": "Test"}], @@ -1131,7 +1131,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager._set_node("next", next_node) + await flow_manager.set_node("next", next_node) # Should see context update and completion trigger again self.assertTrue(self.mock_task.queue_frames.called) @@ -1168,7 +1168,7 @@ async def test_transition_configuration_exclusivity(self): # Should raise error when trying to use both with self.assertRaises(FlowError) as context: - await flow_manager._set_node("test", test_node) + await flow_manager.set_node("test", test_node) self.assertIn( "Cannot specify both transition_to and transition_callback", str(context.exception) ) @@ -1236,7 +1236,7 @@ async def test_node_without_functions(self): } # Set node and verify it works without error - await flow_manager._set_node("no_functions", node_config) + await flow_manager.set_node("no_functions", node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) @@ -1265,7 +1265,7 @@ async def test_node_with_empty_functions(self): } # Set node and verify it works without error - await flow_manager._set_node("empty_functions", node_config) + await flow_manager.set_node("empty_functions", node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) From b910b31ee129552743a11764539b31cff3fa24a1 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 17:35:35 -0400 Subject: [PATCH 43/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20update=20examples=20to=20use=20ne?= =?UTF-8?q?w=20recommended=20practices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/dynamic/insurance_anthropic.py | 94 ++++++++---------- examples/dynamic/insurance_aws_bedrock.py | 106 ++++++++++---------- examples/dynamic/insurance_gemini.py | 108 ++++++++++----------- examples/dynamic/insurance_openai.py | 100 +++++++++---------- examples/dynamic/restaurant_reservation.py | 53 ++++------ examples/dynamic/warm_transfer.py | 89 ++++++++--------- 6 files changed, 253 insertions(+), 297 deletions(-) diff --git a/examples/dynamic/insurance_anthropic.py b/examples/dynamic/insurance_anthropic.py index d05cadd..1684379 100644 --- a/examples/dynamic/insurance_anthropic.py +++ b/examples/dynamic/insurance_anthropic.py @@ -29,7 +29,7 @@ import os import sys from pathlib import Path -from typing import Dict, TypedDict, Union +from typing import TypedDict, Union import aiohttp from dotenv import load_dotenv @@ -88,22 +88,37 @@ class CoverageUpdateResult(FlowResult, InsuranceQuote): # Function handlers -async def collect_age(args: FlowArgs) -> AgeCollectionResult: - """Process age collection.""" +async def collect_age( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[AgeCollectionResult, NodeConfig]: + """Process age collection and return next node name.""" age = args["age"] logger.debug(f"collect_age handler executing with age: {age}") - return AgeCollectionResult(age=age) + flow_manager.state["age"] = age + result = AgeCollectionResult(age=age) -async def collect_marital_status(args: FlowArgs) -> MaritalStatusResult: - """Process marital status collection.""" + next_node = create_marital_status_node() + + return result, next_node + + +async def collect_marital_status( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[MaritalStatusResult, NodeConfig]: + """Process marital status collection and return next node name.""" status = args["marital_status"] logger.debug(f"collect_marital_status handler executing with status: {status}") - return MaritalStatusResult(marital_status=status) + result = MaritalStatusResult(marital_status=status) + + next_node = create_quote_calculation_node(flow_manager.state["age"], status) + + return result, next_node -async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: - """Calculate insurance quote based on age and marital status.""" + +async def calculate_quote(args: FlowArgs) -> tuple[QuoteCalculationResult, NodeConfig]: + """Calculate insurance quote based on age and marital status, return next node name.""" age = args["age"] marital_status = args["marital_status"] logger.debug(f"calculate_quote handler executing with age: {age}, status: {marital_status}") @@ -116,15 +131,17 @@ async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: # Calculate quote monthly_premium = rates["base_rate"] * rates["risk_multiplier"] - return { + result = { "monthly_premium": monthly_premium, "coverage_amount": 250000, "deductible": 1000, } + next_node = create_quote_results_node(result) + return result, next_node -async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: - """Update coverage options and recalculate premium.""" +async def update_coverage(args: FlowArgs) -> tuple[CoverageUpdateResult, NodeConfig]: + """Update coverage options and recalculate premium, return next node name.""" coverage_amount = args["coverage_amount"] deductible = args["deductible"] logger.debug( @@ -136,51 +153,28 @@ async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: if deductible > 1000: monthly_premium *= 0.9 # 10% discount for higher deductible - return { + result = { "monthly_premium": monthly_premium, "coverage_amount": coverage_amount, "deductible": deductible, } + next_node = create_quote_results_node(result) + return result, next_node -async def end_quote() -> FlowResult: - """Handle quote completion.""" +async def end_quote(args: FlowArgs) -> tuple[FlowResult, str]: + """Handle quote completion and return next node name.""" logger.debug("end_quote handler executing") - return {"status": "completed"} - - -# Transition callbacks and handlers -async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): - flow_manager.state["age"] = result["age"] - await flow_manager.set_node("marital_status", create_marital_status_node()) - - -async def handle_marital_status_collection( - args: Dict, result: MaritalStatusResult, flow_manager: FlowManager -): - flow_manager.state["marital_status"] = result["marital_status"] - await flow_manager.set_node( - "quote_calculation", - create_quote_calculation_node( - flow_manager.state["age"], flow_manager.state["marital_status"] - ), - ) - - -async def handle_quote_calculation( - args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager -): - await flow_manager.set_node("quote_results", create_quote_results_node(result)) - - -async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): - await flow_manager.set_node("end", create_end_node()) + result = {"status": "completed"} + next_node = create_end_node() + return result, next_node # Node configurations def create_initial_node() -> NodeConfig: """Create the initial node asking for age.""" return { + "name": "initial", "role_messages": [ { "role": "system", @@ -221,7 +215,6 @@ def create_initial_node() -> NodeConfig: "properties": {"age": {"type": "integer"}}, "required": ["age"], }, - "transition_callback": handle_age_collection, } ], } @@ -230,6 +223,7 @@ def create_initial_node() -> NodeConfig: def create_marital_status_node() -> NodeConfig: """Create node for collecting marital status.""" return { + "name": "marital_status", "task_messages": [ { "role": "user", @@ -253,7 +247,6 @@ def create_marital_status_node() -> NodeConfig: }, "required": ["marital_status"], }, - "transition_callback": handle_marital_status_collection, } ], } @@ -262,6 +255,7 @@ def create_marital_status_node() -> NodeConfig: def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: """Create node for calculating initial quote.""" return { + "name": "quote_calculation", "task_messages": [ { "role": "user", @@ -290,7 +284,6 @@ def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: }, "required": ["age", "marital_status"], }, - "transition_callback": handle_quote_calculation, } ], } @@ -301,6 +294,7 @@ def create_quote_results_node( ) -> NodeConfig: """Create node for showing quote and adjustment options.""" return { + "name": "quote_results", "task_messages": [ { "role": "user", @@ -341,7 +335,6 @@ def create_quote_results_node( "handler": end_quote, "description": "Complete the quote process", "input_schema": {"type": "object", "properties": {}}, - "transition_callback": handle_end_quote, }, ], } @@ -350,6 +343,7 @@ def create_quote_results_node( def create_end_node() -> NodeConfig: """Create the final node.""" return { + "name": "end", "task_messages": [ { "role": "user", @@ -421,9 +415,7 @@ async def main(): async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) # Initialize flow - await flow_manager.initialize() - # Set initial node - await flow_manager.set_node("initial", create_initial_node()) + await flow_manager.initialize(create_initial_node()) # Run the pipeline runner = PipelineRunner() diff --git a/examples/dynamic/insurance_aws_bedrock.py b/examples/dynamic/insurance_aws_bedrock.py index fb72b66..826e3ab 100644 --- a/examples/dynamic/insurance_aws_bedrock.py +++ b/examples/dynamic/insurance_aws_bedrock.py @@ -29,7 +29,7 @@ import os import sys from pathlib import Path -from typing import Dict, TypedDict, Union +from typing import TypedDict, Union import aiohttp from dotenv import load_dotenv @@ -88,21 +88,36 @@ class CoverageUpdateResult(FlowResult, InsuranceQuote): # Function handlers -async def collect_age(args: FlowArgs) -> AgeCollectionResult: +async def collect_age( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[AgeCollectionResult, NodeConfig]: """Process age collection.""" age = args["age"] logger.debug(f"collect_age handler executing with age: {age}") - return AgeCollectionResult(age=age) + flow_manager.state["age"] = age + result = AgeCollectionResult(age=age) -async def collect_marital_status(args: FlowArgs) -> MaritalStatusResult: + next_node = create_marital_status_node() + + return result, next_node + + +async def collect_marital_status( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[MaritalStatusResult, NodeConfig]: """Process marital status collection.""" status = args["marital_status"] logger.debug(f"collect_marital_status handler executing with status: {status}") - return MaritalStatusResult(marital_status=status) + + result = MaritalStatusResult(marital_status=status) + + next_node = create_quote_calculation_node(flow_manager.state["age"], status) + + return result, next_node -async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: +async def calculate_quote(args: FlowArgs) -> tuple[QuoteCalculationResult, NodeConfig]: """Calculate insurance quote based on age and marital status.""" age = args["age"] marital_status = args["marital_status"] @@ -116,14 +131,16 @@ async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: # Calculate quote monthly_premium = rates["base_rate"] * rates["risk_multiplier"] - return { - "monthly_premium": monthly_premium, - "coverage_amount": 250000, - "deductible": 1000, - } + result = QuoteCalculationResult( + monthly_premium=monthly_premium, + coverage_amount=250000, + deductible=1000, + ) + next_node = create_quote_results_node(result) + return result, next_node -async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: +async def update_coverage(args: FlowArgs) -> tuple[CoverageUpdateResult, NodeConfig]: """Update coverage options and recalculate premium.""" coverage_amount = args["coverage_amount"] deductible = args["deductible"] @@ -136,51 +153,28 @@ async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: if deductible > 1000: monthly_premium *= 0.9 # 10% discount for higher deductible - return { - "monthly_premium": monthly_premium, - "coverage_amount": coverage_amount, - "deductible": deductible, - } + result = CoverageUpdateResult( + monthly_premium=monthly_premium, + coverage_amount=coverage_amount, + deductible=deductible, + ) + next_node = create_quote_results_node(result) + return result, next_node -async def end_quote() -> FlowResult: +async def end_quote(args: FlowArgs) -> tuple[FlowResult, NodeConfig]: """Handle quote completion.""" logger.debug("end_quote handler executing") - return {"status": "completed"} - - -# Transition callbacks and handlers -async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): - flow_manager.state["age"] = result["age"] - await flow_manager.set_node("marital_status", create_marital_status_node()) - - -async def handle_marital_status_collection( - args: Dict, result: MaritalStatusResult, flow_manager: FlowManager -): - flow_manager.state["marital_status"] = result["marital_status"] - await flow_manager.set_node( - "quote_calculation", - create_quote_calculation_node( - flow_manager.state["age"], flow_manager.state["marital_status"] - ), - ) - - -async def handle_quote_calculation( - args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager -): - await flow_manager.set_node("quote_results", create_quote_results_node(result)) - - -async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): - await flow_manager.set_node("end", create_end_node()) + result = {"status": "completed"} + next_node = create_end_node() + return result, next_node # Node configurations using FlowsFunctionSchema def create_initial_node() -> NodeConfig: """Create the initial node asking for age.""" return { + "name": "initial", "role_messages": [ { "role": "system", @@ -205,7 +199,6 @@ def create_initial_node() -> NodeConfig: properties={"age": {"type": "integer"}}, required=["age"], handler=collect_age, - transition_callback=handle_age_collection, ) ], } @@ -214,6 +207,7 @@ def create_initial_node() -> NodeConfig: def create_marital_status_node() -> NodeConfig: """Create node for collecting marital status.""" return { + "name": "marital_status", "task_messages": [ { "role": "user", @@ -231,7 +225,6 @@ def create_marital_status_node() -> NodeConfig: properties={"marital_status": {"type": "string", "enum": ["single", "married"]}}, required=["marital_status"], handler=collect_marital_status, - transition_callback=handle_marital_status_collection, ) ], } @@ -240,6 +233,7 @@ def create_marital_status_node() -> NodeConfig: def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: """Create node for calculating initial quote.""" return { + "name": "quote_calculation", "task_messages": [ { "role": "user", @@ -260,7 +254,6 @@ def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: }, required=["age", "marital_status"], handler=calculate_quote, - transition_callback=handle_quote_calculation, ) ], } @@ -271,6 +264,7 @@ def create_quote_results_node( ) -> NodeConfig: """Create node for showing quote and adjustment options.""" return { + "name": "quote_results", "task_messages": [ { "role": "user", @@ -301,10 +295,9 @@ def create_quote_results_node( FlowsFunctionSchema( name="end_quote", description="Complete the quote process when customer is satisfied", - properties={"status": {"type": "string", "enum": ["completed"]}}, - required=["status"], + properties={}, + required=[], handler=end_quote, - transition_callback=handle_end_quote, ), ], } @@ -313,6 +306,7 @@ def create_quote_results_node( def create_end_node() -> NodeConfig: """Create the final node.""" return { + "name": "end", "task_messages": [ { "role": "user", @@ -373,7 +367,7 @@ async def main(): task = PipelineTask(pipeline, params=PipelineParams(allow_interruptions=True)) - # Initialize flow manager with transition callback + # Initialize flow manager flow_manager = FlowManager( task=task, llm=llm, @@ -385,9 +379,7 @@ async def main(): async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) # Initialize flow - await flow_manager.initialize() - # Set initial node - await flow_manager.set_node("initial", create_initial_node()) + await flow_manager.initialize(create_initial_node()) # Run the pipeline runner = PipelineRunner() diff --git a/examples/dynamic/insurance_gemini.py b/examples/dynamic/insurance_gemini.py index 2ed6456..c568bdd 100644 --- a/examples/dynamic/insurance_gemini.py +++ b/examples/dynamic/insurance_gemini.py @@ -29,7 +29,7 @@ import os import sys from pathlib import Path -from typing import Dict, TypedDict, Union +from typing import TypedDict, Union import aiohttp from dotenv import load_dotenv @@ -88,21 +88,36 @@ class CoverageUpdateResult(FlowResult, InsuranceQuote): # Function handlers -async def collect_age(args: FlowArgs) -> AgeCollectionResult: +async def collect_age( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[AgeCollectionResult, NodeConfig]: """Process age collection.""" age = args["age"] logger.debug(f"collect_age handler executing with age: {age}") - return AgeCollectionResult(age=age) + flow_manager.state["age"] = age + result = AgeCollectionResult(age=age) -async def collect_marital_status(args: FlowArgs) -> MaritalStatusResult: + next_node = create_marital_status_node() + + return result, next_node + + +async def collect_marital_status( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[MaritalStatusResult, NodeConfig]: """Process marital status collection.""" status = args["marital_status"] logger.debug(f"collect_marital_status handler executing with status: {status}") - return MaritalStatusResult(marital_status=status) + + result = MaritalStatusResult(marital_status=status) + + next_node = create_quote_calculation_node(flow_manager.state["age"], status) + + return result, next_node -async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: +async def calculate_quote(args: FlowArgs) -> tuple[QuoteCalculationResult, NodeConfig]: """Calculate insurance quote based on age and marital status.""" age = args["age"] marital_status = args["marital_status"] @@ -116,14 +131,16 @@ async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: # Calculate quote monthly_premium = rates["base_rate"] * rates["risk_multiplier"] - return { - "monthly_premium": monthly_premium, - "coverage_amount": 250000, - "deductible": 1000, - } + result = QuoteCalculationResult( + monthly_premium=monthly_premium, + coverage_amount=250000, + deductible=1000, + ) + next_node = create_quote_results_node(result) + return result, next_node -async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: +async def update_coverage(args: FlowArgs) -> tuple[CoverageUpdateResult, NodeConfig]: """Update coverage options and recalculate premium.""" coverage_amount = args["coverage_amount"] deductible = args["deductible"] @@ -136,51 +153,28 @@ async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: if deductible > 1000: monthly_premium *= 0.9 # 10% discount for higher deductible - return { - "monthly_premium": monthly_premium, - "coverage_amount": coverage_amount, - "deductible": deductible, - } + result = CoverageUpdateResult( + monthly_premium=monthly_premium, + coverage_amount=coverage_amount, + deductible=deductible, + ) + next_node = create_quote_results_node(result) + return result, next_node -async def end_quote() -> FlowResult: +async def end_quote(args: FlowArgs) -> tuple[FlowResult, NodeConfig]: """Handle quote completion.""" logger.debug("end_quote handler executing") - return {"status": "completed"} - - -# Transition callbacks and handlers -async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): - flow_manager.state["age"] = result["age"] - await flow_manager.set_node("marital_status", create_marital_status_node()) - - -async def handle_marital_status_collection( - args: Dict, result: MaritalStatusResult, flow_manager: FlowManager -): - flow_manager.state["marital_status"] = result["marital_status"] - await flow_manager.set_node( - "quote_calculation", - create_quote_calculation_node( - flow_manager.state["age"], flow_manager.state["marital_status"] - ), - ) - - -async def handle_quote_calculation( - args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager -): - await flow_manager.set_node("quote_results", create_quote_results_node(result)) - - -async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): - await flow_manager.set_node("end", create_end_node()) + result = {"status": "completed"} + next_node = create_end_node() + return result, next_node # Node configurations using FlowsFunctionSchema def create_initial_node() -> NodeConfig: """Create the initial node asking for age.""" return { + "name": "initial", "role_messages": [ { "role": "system", @@ -205,7 +199,6 @@ def create_initial_node() -> NodeConfig: properties={"age": {"type": "integer"}}, required=["age"], handler=collect_age, - transition_callback=handle_age_collection, ) ], } @@ -214,6 +207,7 @@ def create_initial_node() -> NodeConfig: def create_marital_status_node() -> NodeConfig: """Create node for collecting marital status.""" return { + "name": "marital_status", "task_messages": [ { "role": "system", @@ -228,10 +222,11 @@ def create_marital_status_node() -> NodeConfig: FlowsFunctionSchema( name="collect_marital_status", description="Record marital status after customer provides it", - properties={"marital_status": {"type": "string", "enum": ["single", "married"]}}, + properties={ + "marital_status": {"type": "string", "enum": ["single", "married"]}, + }, required=["marital_status"], handler=collect_marital_status, - transition_callback=handle_marital_status_collection, ) ], } @@ -240,6 +235,7 @@ def create_marital_status_node() -> NodeConfig: def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: """Create node for calculating initial quote.""" return { + "name": "quote_calculation", "task_messages": [ { "role": "system", @@ -260,7 +256,6 @@ def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: }, required=["age", "marital_status"], handler=calculate_quote, - transition_callback=handle_quote_calculation, ) ], } @@ -271,6 +266,7 @@ def create_quote_results_node( ) -> NodeConfig: """Create node for showing quote and adjustment options.""" return { + "name": "quote_results", "task_messages": [ { "role": "system", @@ -301,10 +297,9 @@ def create_quote_results_node( FlowsFunctionSchema( name="end_quote", description="Complete the quote process when customer is satisfied", - properties={"status": {"type": "string", "enum": ["completed"]}}, - required=["status"], + properties={}, + required=[], handler=end_quote, - transition_callback=handle_end_quote, ), ], } @@ -313,6 +308,7 @@ def create_quote_results_node( def create_end_node() -> NodeConfig: """Create the final node.""" return { + "name": "end", "task_messages": [ { "role": "system", @@ -377,9 +373,7 @@ async def main(): async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) # Initialize flow - await flow_manager.initialize() - # Set initial node - await flow_manager.set_node("initial", create_initial_node()) + await flow_manager.initialize(create_initial_node()) # Run the pipeline runner = PipelineRunner() diff --git a/examples/dynamic/insurance_openai.py b/examples/dynamic/insurance_openai.py index d0f3259..ca3ae2c 100644 --- a/examples/dynamic/insurance_openai.py +++ b/examples/dynamic/insurance_openai.py @@ -29,7 +29,7 @@ import os import sys from pathlib import Path -from typing import Dict, TypedDict, Union +from typing import TypedDict, Union import aiohttp from dotenv import load_dotenv @@ -88,21 +88,36 @@ class CoverageUpdateResult(FlowResult, InsuranceQuote): # Function handlers -async def collect_age(args: FlowArgs) -> AgeCollectionResult: +async def collect_age( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[AgeCollectionResult, NodeConfig]: """Process age collection.""" age = args["age"] logger.debug(f"collect_age handler executing with age: {age}") - return AgeCollectionResult(age=age) + flow_manager.state["age"] = age + result = AgeCollectionResult(age=age) -async def collect_marital_status(args: FlowArgs) -> MaritalStatusResult: + next_node = create_marital_status_node() + + return result, next_node + + +async def collect_marital_status( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[MaritalStatusResult, NodeConfig]: """Process marital status collection.""" status = args["marital_status"] logger.debug(f"collect_marital_status handler executing with status: {status}") - return MaritalStatusResult(marital_status=status) + + result = MaritalStatusResult(marital_status=status) + + next_node = create_quote_calculation_node(flow_manager.state["age"], status) + + return result, next_node -async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: +async def calculate_quote(args: FlowArgs) -> tuple[QuoteCalculationResult, NodeConfig]: """Calculate insurance quote based on age and marital status.""" age = args["age"] marital_status = args["marital_status"] @@ -116,14 +131,16 @@ async def calculate_quote(args: FlowArgs) -> QuoteCalculationResult: # Calculate quote monthly_premium = rates["base_rate"] * rates["risk_multiplier"] - return { - "monthly_premium": monthly_premium, - "coverage_amount": 250000, - "deductible": 1000, - } + result = QuoteCalculationResult( + monthly_premium=monthly_premium, + coverage_amount=250000, + deductible=1000, + ) + next_node = create_quote_results_node(result) + return result, next_node -async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: +async def update_coverage(args: FlowArgs) -> tuple[CoverageUpdateResult, NodeConfig]: """Update coverage options and recalculate premium.""" coverage_amount = args["coverage_amount"] deductible = args["deductible"] @@ -136,51 +153,28 @@ async def update_coverage(args: FlowArgs) -> CoverageUpdateResult: if deductible > 1000: monthly_premium *= 0.9 # 10% discount for higher deductible - return { - "monthly_premium": monthly_premium, - "coverage_amount": coverage_amount, - "deductible": deductible, - } + result = CoverageUpdateResult( + monthly_premium=monthly_premium, + coverage_amount=coverage_amount, + deductible=deductible, + ) + next_node = create_quote_results_node(result) + return result, next_node -async def end_quote() -> FlowResult: +async def end_quote(args: FlowArgs) -> tuple[FlowResult, NodeConfig]: """Handle quote completion.""" logger.debug("end_quote handler executing") - return {"status": "completed"} - - -# Transition callbacks and handlers -async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): - flow_manager.state["age"] = result["age"] - await flow_manager.set_node("marital_status", create_marital_status_node()) - - -async def handle_marital_status_collection( - args: Dict, result: MaritalStatusResult, flow_manager: FlowManager -): - flow_manager.state["marital_status"] = result["marital_status"] - await flow_manager.set_node( - "quote_calculation", - create_quote_calculation_node( - flow_manager.state["age"], flow_manager.state["marital_status"] - ), - ) - - -async def handle_quote_calculation( - args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager -): - await flow_manager.set_node("quote_results", create_quote_results_node(result)) - - -async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): - await flow_manager.set_node("end", create_end_node()) + result = {"status": "completed"} + next_node = create_end_node() + return result, next_node # Node configurations def create_initial_node() -> NodeConfig: """Create the initial node asking for age.""" return { + "name": "initial", "role_messages": [ { "role": "system", @@ -209,7 +203,6 @@ def create_initial_node() -> NodeConfig: "properties": {"age": {"type": "integer"}}, "required": ["age"], }, - "transition_callback": handle_age_collection, }, } ], @@ -219,6 +212,7 @@ def create_initial_node() -> NodeConfig: def create_marital_status_node() -> NodeConfig: """Create node for collecting marital status.""" return { + "name": "marital_status", "task_messages": [ { "role": "system", @@ -239,7 +233,6 @@ def create_marital_status_node() -> NodeConfig: }, "required": ["marital_status"], }, - "transition_callback": handle_marital_status_collection, }, } ], @@ -249,6 +242,7 @@ def create_marital_status_node() -> NodeConfig: def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: """Create node for calculating initial quote.""" return { + "name": "quote_calculation", "task_messages": [ { "role": "system", @@ -277,7 +271,6 @@ def create_quote_calculation_node(age: int, marital_status: str) -> NodeConfig: }, "required": ["age", "marital_status"], }, - "transition_callback": handle_quote_calculation, }, } ], @@ -289,6 +282,7 @@ def create_quote_results_node( ) -> NodeConfig: """Create node for showing quote and adjustment options.""" return { + "name": "quote_results", "task_messages": [ { "role": "system", @@ -329,7 +323,6 @@ def create_quote_results_node( "handler": end_quote, "description": "Complete the quote process", "parameters": {"type": "object", "properties": {}}, - "transition_callback": handle_end_quote, }, }, ], @@ -339,6 +332,7 @@ def create_quote_results_node( def create_end_node() -> NodeConfig: """Create the final node.""" return { + "name": "end", "task_messages": [ { "role": "system", @@ -400,9 +394,7 @@ async def main(): async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) logger.debug("Initializing flow") - await flow_manager.initialize() - logger.debug("Setting initial node") - await flow_manager.set_node("initial", create_initial_node()) + await flow_manager.initialize(create_initial_node()) # Run the pipeline runner = PipelineRunner() diff --git a/examples/dynamic/restaurant_reservation.py b/examples/dynamic/restaurant_reservation.py index aeceaa0..b7b9e8c 100644 --- a/examples/dynamic/restaurant_reservation.py +++ b/examples/dynamic/restaurant_reservation.py @@ -81,13 +81,15 @@ class TimeResult(FlowResult): # Function handlers -async def collect_party_size(args: FlowArgs) -> PartySizeResult: +async def collect_party_size(args: FlowArgs) -> tuple[PartySizeResult, NodeConfig]: """Process party size collection.""" size = args["size"] - return PartySizeResult(size=size, status="success") + result = PartySizeResult(size=size, status="success") + next_node = create_time_selection_node() + return result, next_node -async def check_availability(args: FlowArgs) -> TimeResult: +async def check_availability(args: FlowArgs) -> tuple[TimeResult, NodeConfig]: """Check reservation availability and return result.""" time = args["time"] party_size = args["party_size"] @@ -98,38 +100,20 @@ async def check_availability(args: FlowArgs) -> TimeResult: result = TimeResult( status="success", time=time, available=is_available, alternative_times=alternative_times ) - return result - -# Transition handlers -async def handle_party_size_collection( - args: Dict, result: PartySizeResult, flow_manager: FlowManager -): - """Handle party size collection and transition to time selection.""" - # Store party size in flow state - flow_manager.state["party_size"] = result["size"] - await flow_manager.set_node("get_time", create_time_selection_node()) - - -async def handle_availability_check(args: Dict, result: TimeResult, flow_manager: FlowManager): - """Handle availability check result and transition based on availability.""" - # Store reservation details in flow state - flow_manager.state["requested_time"] = args["time"] - - # Use result directly instead of accessing state - if result["available"]: + if is_available: logger.debug("Time is available, transitioning to confirmation node") - await flow_manager.set_node("confirm", create_confirmation_node()) + next_node = create_confirmation_node() else: logger.debug(f"Time not available, storing alternatives: {result['alternative_times']}") - await flow_manager.set_node( - "no_availability", create_no_availability_node(result["alternative_times"]) - ) + next_node = create_no_availability_node(result["alternative_times"]) + + return result, next_node -async def handle_end(_: Dict, result: FlowResult, flow_manager: FlowManager): +async def end_conversation(args: FlowArgs) -> tuple[None, NodeConfig]: """Handle conversation end.""" - await flow_manager.set_node("end", create_end_node()) + return None, create_end_node() # Create function schemas @@ -139,7 +123,6 @@ async def handle_end(_: Dict, result: FlowResult, flow_manager: FlowManager): properties={"size": {"type": "integer", "minimum": 1, "maximum": 12}}, required=["size"], handler=collect_party_size, - transition_callback=handle_party_size_collection, ) availability_schema = FlowsFunctionSchema( @@ -155,7 +138,6 @@ async def handle_end(_: Dict, result: FlowResult, flow_manager: FlowManager): }, required=["time", "party_size"], handler=check_availability, - transition_callback=handle_availability_check, ) end_conversation_schema = FlowsFunctionSchema( @@ -163,7 +145,7 @@ async def handle_end(_: Dict, result: FlowResult, flow_manager: FlowManager): description="End the conversation", properties={}, required=[], - transition_callback=handle_end, + handler=end_conversation, ) @@ -171,6 +153,7 @@ async def handle_end(_: Dict, result: FlowResult, flow_manager: FlowManager): def create_initial_node(wait_for_user: bool) -> NodeConfig: """Create initial node for party size collection.""" return { + "name": "initial", "role_messages": [ { "role": "system", @@ -192,6 +175,7 @@ def create_time_selection_node() -> NodeConfig: """Create node for time selection and availability check.""" logger.debug("Creating time selection node") return { + "name": "get_time", "task_messages": [ { "role": "system", @@ -205,6 +189,7 @@ def create_time_selection_node() -> NodeConfig: def create_confirmation_node() -> NodeConfig: """Create confirmation node for successful reservations.""" return { + "name": "confirm", "task_messages": [ { "role": "system", @@ -219,6 +204,7 @@ def create_no_availability_node(alternative_times: list[str]) -> NodeConfig: """Create node for handling no availability.""" times_list = ", ".join(alternative_times) return { + "name": "no_availability", "task_messages": [ { "role": "system", @@ -236,6 +222,7 @@ def create_no_availability_node(alternative_times: list[str]) -> NodeConfig: def create_end_node() -> NodeConfig: """Create the final node.""" return { + "name": "end", "task_messages": [ { "role": "system", @@ -297,9 +284,7 @@ async def main(wait_for_user: bool): async def on_first_participant_joined(transport, participant): await transport.capture_participant_transcription(participant["id"]) logger.debug("Initializing flow manager") - await flow_manager.initialize() - logger.debug("Setting initial node") - await flow_manager.set_node("initial", create_initial_node(wait_for_user)) + await flow_manager.initialize(create_initial_node(wait_for_user)) runner = PipelineRunner() await runner.run(task) diff --git a/examples/dynamic/warm_transfer.py b/examples/dynamic/warm_transfer.py index be3a8d7..ab279c8 100644 --- a/examples/dynamic/warm_transfer.py +++ b/examples/dynamic/warm_transfer.py @@ -61,7 +61,7 @@ ) from pipecat_flows import ContextStrategyConfig, FlowManager, FlowResult, NodeConfig -from pipecat_flows.types import ActionConfig, ContextStrategy, FlowsFunctionSchema +from pipecat_flows.types import ActionConfig, ContextStrategy, FlowArgs, FlowsFunctionSchema sys.path.append(str(Path(__file__).parent.parent)) from runner import configure @@ -79,8 +79,9 @@ # - check_store_location_and_hours_of_operation (always succeeds) # - start_order (always fails) # - end_customer_conversation -# Transition: -# - transition_after_customer_task (to either continued_customer_interaction or transferring_to_human_agent) +# Transitions to either: +# - continued_customer_interaction +# - transferring_to_human_agent # # 2. continued_customer_interaction # The bot has already helped the customer with something. Now they're helping them with something else. @@ -88,8 +89,9 @@ # - check_store_location_and_hours_of_operation (always succeeds) # - start_order (always fails) # - end_customer_conversation -# Transition: -# - transition_after_customer_task (to either continued_customer_interaction or transferring_to_human_agent) +# Transitions to either: +# - continued_customer_interaction +# - transferring_to_human_agent # # 3. transferring_to_human_agent # The customer is asked to please hold while the bot transfers them to a human agent. Hold music plays while the customer waits. @@ -121,18 +123,24 @@ class StartOrderResult(FlowResult): # Function handlers -async def check_store_location_and_hours_of_operation() -> StoreLocationAndHoursOfOperationResult: +async def check_store_location_and_hours_of_operation() -> tuple[ + StoreLocationAndHoursOfOperationResult, NodeConfig +]: """Check store location and hours of operation.""" - return StoreLocationAndHoursOfOperationResult( + result = StoreLocationAndHoursOfOperationResult( status="success", store_location="123 Main St, Anytown, USA", hours_of_operation="9am to 5pm, Monday through Friday", ) + next_node = next_node_after_customer_task(result) + return result, next_node -async def start_order() -> StartOrderResult: +async def start_order() -> tuple[StartOrderResult, NodeConfig]: """Start a new order.""" - return StartOrderResult(status="error") + result = StartOrderResult(status="error") + next_node = next_node_after_customer_task(result) + return result, next_node # Action handlers @@ -216,49 +224,43 @@ async def unmute_customer_and_make_humans_hear_each_other(action: dict, flow_man ) -# Transitions -async def start_customer_interaction(flow_manager: FlowManager): - """Transition to the "customer_interaction" node""" - await flow_manager.set_node("customer_interaction", create_initial_customer_interaction_node()) +# Functions +async def end_customer_conversation( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[None, NodeConfig]: + """Transition to the "end_customer_conversation" node.""" + return None, create_end_customer_conversation_node() + + +async def end_human_agent_conversation( + args: FlowArgs, flow_manager: FlowManager +) -> tuple[None, NodeConfig]: + """Transition to the "end_human_agent_conversation" node.""" + return None, create_end_human_agent_conversation_node() -async def transition_after_customer_task(args: Dict, result: FlowResult, flow_manager: FlowManager): +# Helpers +def next_node_after_customer_task(result: FlowResult) -> NodeConfig: """Transition to either the "continued_customer_interaction" node or "transferring_to_human_agent" node, depending on the outcome of the previous customer task""" if result.get("status") == "success": - await flow_manager.set_node( - "continued_customer_interaction", create_continued_customer_interaction_node() - ) + return create_continued_customer_interaction_node() else: - await flow_manager.set_node( - "transferring_to_human_agent", create_transferring_to_human_agent_node() - ) + return create_transferring_to_human_agent_node() +# Transitions async def start_human_agent_interaction(flow_manager: FlowManager): """Transition to the "human_agent_interaction" node.""" await flow_manager.set_node("human_agent_interaction", create_human_agent_interaction_node()) -async def end_customer_conversation(args: Dict, flow_manager: FlowManager): - """Transition to the "end_customer_conversation" node.""" - await flow_manager.set_node( - "end_customer_conversation", create_end_customer_conversation_node() - ) - - -async def end_human_agent_conversation(args: Dict, flow_manager: FlowManager): - """Transition to the "end_human_agent_conversation" node.""" - await flow_manager.set_node( - "end_human_agent_conversation", create_end_human_agent_conversation_node() - ) - - # Node configuration def create_initial_customer_interaction_node() -> NodeConfig: """Create the "initial_customer_interaction" node. This is the initial node where the bot interacts with the customer and tries to help with their requests. """ return NodeConfig( + name="customer_interaction", role_messages=[ { "role": "system", @@ -283,7 +285,6 @@ def create_initial_customer_interaction_node() -> NodeConfig: name="check_store_location_and_hours_of_operation", description="Check store location and hours of operation", handler=check_store_location_and_hours_of_operation, - transition_callback=transition_after_customer_task, properties={}, required=[], ), @@ -291,14 +292,13 @@ def create_initial_customer_interaction_node() -> NodeConfig: name="start_order", description="Start placing an order", handler=start_order, - transition_callback=transition_after_customer_task, properties={}, required=[], ), FlowsFunctionSchema( name="end_customer_conversation", description="End the conversation", - transition_callback=end_customer_conversation, + handler=end_customer_conversation, properties={}, required=[], ), @@ -312,6 +312,7 @@ def create_continued_customer_interaction_node() -> NodeConfig: It assumes that the bot has already previously helped the customer with something. """ return NodeConfig( + name="continued_customer_interaction", task_messages=[ { "role": "system", @@ -330,7 +331,6 @@ def create_continued_customer_interaction_node() -> NodeConfig: name="check_store_location_and_hours_of_operation", description="Check store location and hours of operation", handler=check_store_location_and_hours_of_operation, - transition_callback=transition_after_customer_task, properties={}, required=[], ), @@ -338,14 +338,13 @@ def create_continued_customer_interaction_node() -> NodeConfig: name="start_order", description="Start placing an order", handler=start_order, - transition_callback=transition_after_customer_task, properties={}, required=[], ), FlowsFunctionSchema( name="end_customer_conversation", description="End the conversation", - transition_callback=end_customer_conversation, + handler=end_customer_conversation, properties={}, required=[], ), @@ -358,6 +357,7 @@ def create_transferring_to_human_agent_node() -> NodeConfig: This is the node where the customer is asked to please hold while the bot transfers them to a human agent. Hold music plays while the customer waits. """ return NodeConfig( + name="transferring_to_human_agent", task_messages=[ { "role": "system", @@ -381,6 +381,7 @@ def create_human_agent_interaction_node() -> NodeConfig: The customer continues to hear hold music. """ return NodeConfig( + name="human_agent_interaction", task_messages=[ { "role": "system", @@ -400,7 +401,7 @@ def create_human_agent_interaction_node() -> NodeConfig: FlowsFunctionSchema( name="connect_human_agent_and_customer", description="Connect the human agent to the customer", - transition_callback=end_human_agent_conversation, + handler=end_human_agent_conversation, properties={}, required=[], ) @@ -414,6 +415,7 @@ def create_end_customer_conversation_node() -> NodeConfig: This is how a conversation ends when a human agent did not need to be brought in. """ return NodeConfig( + name="end_customer_conversation", task_messages=[ { "role": "system", @@ -429,6 +431,7 @@ def create_end_human_agent_conversation_node() -> NodeConfig: This is the node where the bot tells the agent that they're being patched through to the customer and ends the conversation (leaving the customer and agent in the room talking to each other). """ return NodeConfig( + name="end_human_agent_conversation", task_messages=[ { "role": "system", @@ -621,9 +624,7 @@ async def on_first_participant_joined( """ await transport.capture_participant_transcription(participant["id"]) # Initialize flow - await flow_manager.initialize() - # Set initial node - await start_customer_interaction(flow_manager=flow_manager) + await flow_manager.initialize(create_initial_customer_interaction_node()) @transport.event_handler("on_participant_joined") async def on_participant_joined(transport: DailyTransport, participant: Dict[str, Any]): From ea5d9fbddd78f97d79c36a6843405a76e9c39601 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 22:23:30 -0400 Subject: [PATCH 44/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20remove=20unused=20import?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pipecat_flows/manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 1288042..56cfff8 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -26,7 +26,6 @@ import asyncio import inspect import sys -import uuid import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union, cast From 8da25bf7ee4bf3010294f18c00aa713ddf75191e Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 22:47:22 -0400 Subject: [PATCH 45/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20Add=20`set=5Fnode=5Ffrom=5Fconfig?= =?UTF-8?q?()`,=20an=20alternative=20to=20`set=5Fnode()`=20that=20doesn't?= =?UTF-8?q?=20require=20the=20customer=20to=20provide=20a=20node=20name,?= =?UTF-8?q?=20since=20they=20can=20do=20so=20from=20within=20their=20node?= =?UTF-8?q?=20config.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hopefully users won't need to directly set nodes often, what with the introduction of: - "consolidated" functions that return the next node, eliminating the need for transition callbacks - `initialize()` taking the initial node `set_node_from_config()` is a bit clunky, but I couldn’t find a way to reuse `set_node()` without breaking API compatibility. There are ways of fudging method overloading in Python, but not without breaking keyword-argument-based invocation, it seems. --- examples/dynamic/warm_transfer.py | 2 +- src/pipecat_flows/manager.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/examples/dynamic/warm_transfer.py b/examples/dynamic/warm_transfer.py index ab279c8..2d63559 100644 --- a/examples/dynamic/warm_transfer.py +++ b/examples/dynamic/warm_transfer.py @@ -251,7 +251,7 @@ def next_node_after_customer_task(result: FlowResult) -> NodeConfig: # Transitions async def start_human_agent_interaction(flow_manager: FlowManager): """Transition to the "human_agent_interaction" node.""" - await flow_manager.set_node("human_agent_interaction", create_human_agent_interaction_node()) + await flow_manager.set_node_from_config(create_human_agent_interaction_node()) # Node configuration diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 56cfff8..5056d75 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -542,6 +542,16 @@ async def _register_function( logger.error(f"Failed to register function {name}: {str(e)}") raise FlowError(f"Function registration failed: {str(e)}") from e + async def set_node_from_config( + self, node_config: NodeConfig + ) -> None: + """Set up a new conversation node and transition to it. + + Args: + node_config: Configuration for the new node + """ + await self.set_node(get_or_generate_node_name(node_config), node_config) + async def set_node(self, node_id: str, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. From 1d913b655f373fbd07df0d395c725c1125b3d3b3 Mon Sep 17 00:00:00 2001 From: Paul Kompfner Date: Fri, 30 May 2025 23:02:31 -0400 Subject: [PATCH 46/47] =?UTF-8?q?[WIP]=20More=20progress=20on=20=E2=80=9Cd?= =?UTF-8?q?irect=20functions=E2=80=9D:=20deprecate=20`set=5Fnode()`=20(tak?= =?UTF-8?q?e=202).=20This=20time=20we=20have=20an=20alternative:=20`set=5F?= =?UTF-8?q?node=5Ffrom=5Fconfig()`.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 10 ++++++ src/pipecat_flows/manager.py | 45 +++++++++++++++++++----- tests/test_context_strategies.py | 28 +++++++-------- tests/test_manager.py | 60 ++++++++++++++++---------------- 4 files changed, 91 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 50a8e04..36f3136 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,6 +104,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 using `FlowsFunctionSchema`s or function definition dicts entirely. See the "Added" section above for more details. +- Deprecated `set_node()` in favor of doing the following for dynamic flows: + + - Prefer "consolidated" or "direct" functions that return a tuple (result, next node) over + deprecated `transition_callback`s + - Pass your initial node to `FlowManager.initialize()` + - If you really need to set a node explicitly, use `set_node_from_config()` + + In all of these cases, you can provide a `name` in your new node's config for debug logging + purposes. + ### Changed - `functions` are now optional in the `NodeConfig`. Additionally, for AWS diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 5056d75..c966593 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -144,6 +144,7 @@ def __init__( self.current_node: Optional[str] = None self._showed_deprecation_warning_for_transition_fields = False + self._showed_deprecation_warning_for_set_node = False def _validate_transition_callback(self, name: str, callback: Any) -> None: """Validate a transition callback. @@ -189,7 +190,7 @@ async def initialize(self, initial_node: Optional[NodeConfig] = None) -> None: node = initial_node if node_name: logger.debug(f"Setting initial node: {node_name}") - await self.set_node(node_name, node) + await self._set_node(node_name, node) except Exception as e: self.initialized = False @@ -358,10 +359,10 @@ async def on_context_updated_edge( node_name = get_or_generate_node_name(next_node) node = next_node logger.debug(f"Transition to function-returned node: {node_name}") - await self.set_node(node_name, node) + await self._set_node(node_name, node) elif transition_to: # Static flow logger.debug(f"Static transition to: {transition_to}") - await self.set_node(transition_to, self.nodes[transition_to]) + await self._set_node(transition_to, self.nodes[transition_to]) elif transition_callback: # Dynamic flow logger.debug(f"Dynamic transition for: {name}") # Check callback signature @@ -542,19 +543,47 @@ async def _register_function( logger.error(f"Failed to register function {name}: {str(e)}") raise FlowError(f"Function registration failed: {str(e)}") from e - async def set_node_from_config( - self, node_config: NodeConfig - ) -> None: + async def set_node_from_config(self, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. - + Args: node_config: Configuration for the new node + + Raises: + FlowTransitionError: If manager not initialized + FlowError: If node setup fails """ - await self.set_node(get_or_generate_node_name(node_config), node_config) + await self._set_node(get_or_generate_node_name(node_config), node_config) async def set_node(self, node_id: str, node_config: NodeConfig) -> None: """Set up a new conversation node and transition to it. + Args: + node_id: Identifier for the new node + node_config: Configuration for the new node + + Raises: + FlowTransitionError: If manager not initialized + FlowError: If node setup fails + """ + if not self._showed_deprecation_warning_for_set_node: + self._showed_deprecation_warning_for_set_node = True + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + """`set_node()` is deprecated and will be removed in a future version. Instead, do the following for dynamic flows: +- Prefer "consolidated" or "direct" functions that return a tuple (result, next_node) over deprecated `transition_callback`s +- Pass your initial node to `FlowManager.initialize()` +- If you really need to set a node explicitly, use `set_node_from_config()` +In all of these cases, you can provide a `name` in your new node's config for debug logging purposes.""", + DeprecationWarning, + stacklevel=2, + ) + await self._set_node(node_id, node_config) + + async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: + """Set up a new conversation node and transition to it. + Handles the complete node transition process in the following order: 1. Execute pre-actions (if any) 2. Set up messages (role and task) diff --git a/tests/test_context_strategies.py b/tests/test_context_strategies.py index 2c27ddb..715213e 100644 --- a/tests/test_context_strategies.py +++ b/tests/test_context_strategies.py @@ -96,7 +96,7 @@ async def test_default_strategy(self): await flow_manager.initialize() # First node should use UpdateFrame regardless of strategy - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) first_call = self.mock_task.queue_frames.call_args_list[0] first_frames = first_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in first_frames)) @@ -105,7 +105,7 @@ async def test_default_strategy(self): self.mock_task.queue_frames.reset_mock() # Subsequent node should use AppendFrame with default strategy - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) second_call = self.mock_task.queue_frames.call_args_list[0] second_frames = second_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesAppendFrame) for f in second_frames)) @@ -121,11 +121,11 @@ async def test_reset_strategy(self): await flow_manager.initialize() # Set initial node - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() # Second node should use UpdateFrame with RESET strategy - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) second_call = self.mock_task.queue_frames.call_args_list[0] second_frames = second_call[0][0] self.assertTrue(any(isinstance(f, LLMMessagesUpdateFrame) for f in second_frames)) @@ -150,10 +150,10 @@ async def test_reset_with_summary_success(self): await flow_manager.initialize() # Set nodes and verify summary inclusion - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) # Verify summary was included in context update second_call = self.mock_task.queue_frames.call_args_list[0] @@ -180,10 +180,10 @@ async def test_reset_with_summary_timeout(self): ) # Set nodes and verify fallback to RESET - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("second", self.sample_node) # Verify UpdateFrame was used (RESET behavior) second_call = self.mock_task.queue_frames.call_args_list[0] @@ -238,10 +238,10 @@ async def test_node_level_strategy_override(self): } # Set nodes and verify strategy override - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", node_with_strategy) + await flow_manager._set_node("second", node_with_strategy) # Verify UpdateFrame was used (RESET behavior) despite global APPEND second_call = self.mock_task.queue_frames.call_args_list[0] @@ -267,8 +267,8 @@ async def test_summary_generation_content(self): await flow_manager.initialize() # Set nodes to trigger summary generation - await flow_manager.set_node("first", self.sample_node) - await flow_manager.set_node("second", self.sample_node) + await flow_manager._set_node("first", self.sample_node) + await flow_manager._set_node("second", self.sample_node) # Verify summary generation call create_call = self.mock_llm._client.chat.completions.create.call_args @@ -299,7 +299,7 @@ async def test_context_structure_after_summary(self): await flow_manager.initialize() # Set nodes to trigger summary generation - await flow_manager.set_node("first", self.sample_node) + await flow_manager._set_node("first", self.sample_node) self.mock_task.queue_frames.reset_mock() # Node with new task messages @@ -307,7 +307,7 @@ async def test_context_structure_after_summary(self): "task_messages": [{"role": "system", "content": "New task."}], "functions": [], } - await flow_manager.set_node("second", new_node) + await flow_manager._set_node("second", new_node) # Verify context structure update_call = self.mock_task.queue_frames.call_args_list[0] diff --git a/tests/test_manager.py b/tests/test_manager.py index 84e167d..912adf5 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -134,7 +134,7 @@ async def test_dynamic_flow_initialization(self): # Initialize and set node await flow_manager.initialize() - await flow_manager.set_node("test", test_node) + await flow_manager._set_node("test", test_node) self.assertFalse( mock_transition_handler.called @@ -160,7 +160,7 @@ async def test_static_flow_transitions(self): # In static flows, transitions happen through set_node with a # predefined node configuration from the flow_config - await flow_manager.set_node("next_node", flow_manager.nodes["next_node"]) + await flow_manager._set_node("next_node", flow_manager.nodes["next_node"]) # Verify node transition occurred self.assertEqual(flow_manager.current_node, "next_node") @@ -231,7 +231,7 @@ async def test_handler(args: FlowArgs) -> FlowResult: } # Test old style callback - await flow_manager.set_node("old_style", old_style_node) + await flow_manager._set_node("old_style", old_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Store the context_updated callback @@ -279,7 +279,7 @@ async def result_callback(result, properties=None): } # Test new style callback - await flow_manager.set_node("new_style", new_style_node) + await flow_manager._set_node("new_style", new_style_node) func = flow_manager.llm.register_function.call_args[0][1] # Reset context_updated callback @@ -316,12 +316,12 @@ async def test_node_validation(self): # Test missing task_messages invalid_config = {"functions": []} with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", invalid_config) + await flow_manager._set_node("test", invalid_config) self.assertIn("missing required 'task_messages' field", str(context.exception)) # Test valid config valid_config = {"task_messages": []} - await flow_manager.set_node("test", valid_config) + await flow_manager._set_node("test", valid_config) self.assertEqual(flow_manager.current_node, "test") self.assertEqual(flow_manager.current_functions, set()) @@ -339,7 +339,7 @@ async def test_function_registration(self): self.mock_llm.register_function.reset_mock() # Set node with function - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) # Verify function was registered self.mock_llm.register_function.assert_called_once() @@ -370,7 +370,7 @@ async def test_action_execution(self): self.mock_tts.say.reset_mock() # Set node with actions - await flow_manager.set_node("test", node_with_actions) + await flow_manager._set_node("test", node_with_actions) # Verify TTS was called for both actions self.assertEqual(self.mock_tts.say.call_count, 2) @@ -393,7 +393,7 @@ async def test_error_handling(self): # Test setting node before initialization with self.assertRaises(FlowTransitionError): - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) # Initialize normally await flow_manager.initialize() @@ -402,7 +402,7 @@ async def test_error_handling(self): # Test node setting error self.mock_task.queue_frames.side_effect = Exception("Queue error") with self.assertRaises(FlowError): - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) # Verify flow manager remains initialized despite error self.assertTrue(flow_manager.initialized) @@ -424,7 +424,7 @@ async def test_state_management(self): self.mock_task.queue_frames.reset_mock() # Verify state persists across node transitions - await flow_manager.set_node("test", self.sample_node) + await flow_manager._set_node("test", self.sample_node) self.assertEqual(flow_manager.state["test_key"], test_value) async def test_multiple_function_registration(self): @@ -452,7 +452,7 @@ async def test_multiple_function_registration(self): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Verify all functions were registered self.assertEqual(self.mock_llm.register_function.call_count, 3) @@ -562,7 +562,7 @@ async def test_node_validation_edge_cases(self): "functions": [{"type": "function"}], # Missing name } with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", invalid_config) + await flow_manager._set_node("test", invalid_config) self.assertIn("invalid format", str(context.exception)) # Test node function without handler or transition_to @@ -588,7 +588,7 @@ def capture_warning(msg, *args, **kwargs): warning_message = msg with patch("loguru.logger.warning", side_effect=capture_warning): - await flow_manager.set_node("test", invalid_config) + await flow_manager._set_node("test", invalid_config) self.assertIsNotNone(warning_message) self.assertIn( "Function 'test_func' in node 'test' has neither handler, transition_to, nor transition_callback", @@ -627,7 +627,7 @@ async def failing_handler(args, flow_manager): } # Set up node and get registered function - await flow_manager.set_node("test", test_node) + await flow_manager._set_node("test", test_node) transition_func = flow_manager.llm.register_function.call_args[0][1] # Track the result and context_updated callback @@ -702,7 +702,7 @@ async def test_action_execution_error_handling(self): # Should raise FlowError due to invalid actions with self.assertRaises(FlowError): - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Verify error handling for pre and post actions separately with self.assertRaises(FlowError): @@ -766,7 +766,7 @@ async def test_handler(args): } # Set node and verify function registration - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Verify both functions were registered self.assertIn("test1", flow_manager.current_functions) @@ -806,7 +806,7 @@ async def test_handler_main(args): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) self.assertIn("test_function", flow_manager.current_functions) finally: @@ -838,7 +838,7 @@ async def test_function_token_handling_not_found(self): } with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) self.assertIn("Function 'nonexistent_handler' not found", str(context.exception)) @@ -879,7 +879,7 @@ async def test_handler(args): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Get the registered function and test it name, func = self.mock_llm.register_function.call_args[0] @@ -928,7 +928,7 @@ async def test_role_message_inheritance(self): } # Set first node and verify UpdateFrame - await flow_manager.set_node("first", first_node) + await flow_manager._set_node("first", first_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] update_frames = [f for f in first_frames if isinstance(f, LLMMessagesUpdateFrame)] @@ -940,7 +940,7 @@ async def test_role_message_inheritance(self): # Reset mock and set second node self.mock_task.queue_frames.reset_mock() - await flow_manager.set_node("second", second_node) + await flow_manager._set_node("second", second_node) # Verify AppendFrame for second node first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call @@ -966,7 +966,7 @@ async def test_frame_type_selection(self): } # First node should use UpdateFrame - await flow_manager.set_node("first", test_node) + await flow_manager._set_node("first", test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call first_frames = first_call[0][0] self.assertTrue( @@ -982,7 +982,7 @@ async def test_frame_type_selection(self): self.mock_task.queue_frames.reset_mock() # Second node should use AppendFrame - await flow_manager.set_node("second", test_node) + await flow_manager._set_node("second", test_node) first_call = self.mock_task.queue_frames.call_args_list[0] # Get first call second_frames = first_call[0][0] self.assertTrue( @@ -1033,7 +1033,7 @@ async def test_handler(args): ], } - await flow_manager.set_node("test", node_config) + await flow_manager._set_node("test", node_config) # Get the registered functions node_func = None @@ -1105,7 +1105,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager.set_node( + await flow_manager._set_node( "initial", { "task_messages": [{"role": "system", "content": "Test"}], @@ -1131,7 +1131,7 @@ async def test_completion_timing(self): self.mock_task.queue_frames.reset_mock() self.mock_context_aggregator.user().get_context_frame.reset_mock() - await flow_manager.set_node("next", next_node) + await flow_manager._set_node("next", next_node) # Should see context update and completion trigger again self.assertTrue(self.mock_task.queue_frames.called) @@ -1168,7 +1168,7 @@ async def test_transition_configuration_exclusivity(self): # Should raise error when trying to use both with self.assertRaises(FlowError) as context: - await flow_manager.set_node("test", test_node) + await flow_manager._set_node("test", test_node) self.assertIn( "Cannot specify both transition_to and transition_callback", str(context.exception) ) @@ -1236,7 +1236,7 @@ async def test_node_without_functions(self): } # Set node and verify it works without error - await flow_manager.set_node("no_functions", node_config) + await flow_manager._set_node("no_functions", node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) @@ -1265,7 +1265,7 @@ async def test_node_with_empty_functions(self): } # Set node and verify it works without error - await flow_manager.set_node("empty_functions", node_config) + await flow_manager._set_node("empty_functions", node_config) # Verify current_functions is empty set self.assertEqual(flow_manager.current_functions, set()) From c7c998c384416992fbf3ebf25c8f7ff8aa148d3f Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Fri, 13 Jun 2025 12:09:14 -0400 Subject: [PATCH 47/47] Fix: Handle run_in_parallel=False, simplify pending function call tracking --- CHANGELOG.md | 7 + .../static/food_ordering_direct_functions.py | 22 +-- src/pipecat_flows/manager.py | 183 ++++++++---------- 3 files changed, 92 insertions(+), 120 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 36f3136..115a213 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 you to either omit `functions` for nodes, which is common for the end node, or specify an empty function call list, if desired. +### Fixed + +- Fixed an issue where if `run_in_parallel=False` was set for the LLM, the bot + would trigger N completions for each sequential function call. Now, Flows + uses Pipecat's internal function tracking to determine when there are more + edge functions to call. + ## [0.0.17] - 2025-05-16 ### Added diff --git a/examples/static/food_ordering_direct_functions.py b/examples/static/food_ordering_direct_functions.py index c90e4a3..003573e 100644 --- a/examples/static/food_ordering_direct_functions.py +++ b/examples/static/food_ordering_direct_functions.py @@ -92,24 +92,19 @@ async def check_kitchen_status(action: dict) -> None: # Functions async def choose_pizza(flow_manager: FlowManager) -> tuple[None, str]: - """ - User wants to order pizza. Let's get that order started. - """ + """User wants to order pizza. Let's get that order started.""" return None, "choose_pizza" async def choose_sushi(flow_manager: FlowManager) -> tuple[None, str]: - """ - User wants to order sushi. Let's get that order started. - """ + """User wants to order sushi. Let's get that order started.""" return None, "choose_sushi" async def select_pizza_order( flow_manager: FlowManager, size: str, pizza_type: str ) -> tuple[PizzaOrderResult, str]: - """ - Record the pizza order details. + """Record the pizza order details. Args: size (str): Size of the pizza. Must be one of "small", "medium", or "large". @@ -125,8 +120,7 @@ async def select_pizza_order( async def select_sushi_order( flow_manager: FlowManager, count: int, roll_type: str ) -> tuple[SushiOrderResult, str]: - """ - Record the sushi order details. + """Record the sushi order details. Args: count (int): Number of sushi rolls to order. Must be between 1 and 10. @@ -139,16 +133,12 @@ async def select_sushi_order( async def complete_order(flow_manager: FlowManager) -> tuple[None, str]: - """ - User confirms the order is correct. - """ + """User confirms the order is correct.""" return None, "end" async def revise_order(flow_manager: FlowManager) -> tuple[None, str]: - """ - User wants to make changes to their order. - """ + """User wants to make changes to their order.""" return None, "start" diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index c966593..bb308c9 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -123,7 +123,7 @@ def __init__( self.adapter = create_adapter(llm) self.initialized = False self._context_aggregator = context_aggregator - self._pending_function_calls = 0 + self._pending_transition: Optional[Dict[str, Any]] = None self._context_strategy = context_strategy or ContextStrategyConfig( strategy=ContextStrategy.APPEND ) @@ -290,6 +290,56 @@ async def _call_handler( # Modern handler with args and flow_manager return await handler(args, self) + async def _check_and_execute_transition(self) -> None: + """Check if all functions are complete and execute transition if so.""" + if not self._pending_transition: + return + + # Check if all function calls are complete using Pipecat's state + assistant_aggregator = self._context_aggregator.assistant() + if not assistant_aggregator._function_calls_in_progress: + # All functions complete, execute transition + transition_info = self._pending_transition + self._pending_transition = None + + await self._execute_transition(transition_info) + + async def _execute_transition(self, transition_info: Dict[str, Any]) -> None: + """Execute the stored transition.""" + next_node = transition_info.get("next_node") + transition_to = transition_info.get("transition_to") + transition_callback = transition_info.get("transition_callback") + function_name = transition_info.get("function_name") + arguments = transition_info.get("arguments") + result = transition_info.get("result") + + try: + if next_node: # Function-returned next node (consolidated function) + if isinstance(next_node, str): # Static flow + node_name = next_node + node = self.nodes[next_node] + else: # Dynamic flow + node_name = get_or_generate_node_name(next_node) + node = next_node + logger.debug(f"Transition to function-returned node: {node_name}") + await self._set_node(node_name, node) + elif transition_to: # Static flow (deprecated) + logger.debug(f"Static transition to: {transition_to}") + await self._set_node(transition_to, self.nodes[transition_to]) + elif transition_callback: # Dynamic flow (deprecated) + logger.debug(f"Dynamic transition for: {function_name}") + # Check callback signature + sig = inspect.signature(transition_callback) + if len(sig.parameters) == 2: + # Old style: (args, flow_manager) + await transition_callback(arguments, self) + else: + # New style: (args, result, flow_manager) + await transition_callback(arguments, result, self) + except Exception as e: + logger.error(f"Error executing transition: {str(e)}") + raise + async def _create_transition_func( self, name: str, @@ -320,87 +370,10 @@ async def _create_transition_func( if transition_callback: self._validate_transition_callback(name, transition_callback) - def decrease_pending_function_calls() -> None: - """Decrease the pending function calls counter if greater than zero.""" - if self._pending_function_calls > 0: - self._pending_function_calls -= 1 - logger.debug( - f"Function call completed: {name} (remaining: {self._pending_function_calls})" - ) - - async def on_context_updated_edge( - next_node: Optional[NodeConfig | str], - args: Optional[Dict[str, Any]], - result: Optional[Any], - result_callback: Callable, - ) -> None: - """ - Handle context updates for edge functions with transitions. - - If `next_node` is provided: - - Ignore `args` and `result` and just transition to it. - - Otherwise, if `transition_to` is available: - - Use it to look up the next node. - - Otherwise, if `transition_callback` is provided: - - Call it with `args` and `result` to determine the next node. - """ - try: - decrease_pending_function_calls() - - # Only process transition if this was the last pending call - if self._pending_function_calls == 0: - if next_node: # Function-returned next node (as opposed to next node specified by transition_*) - if isinstance(next_node, str): # Static flow - node_name = next_node - node = self.nodes[next_node] - else: # Dynamic flow - node_name = get_or_generate_node_name(next_node) - node = next_node - logger.debug(f"Transition to function-returned node: {node_name}") - await self._set_node(node_name, node) - elif transition_to: # Static flow - logger.debug(f"Static transition to: {transition_to}") - await self._set_node(transition_to, self.nodes[transition_to]) - elif transition_callback: # Dynamic flow - logger.debug(f"Dynamic transition for: {name}") - # Check callback signature - sig = inspect.signature(transition_callback) - if len(sig.parameters) == 2: - # Old style: (args, flow_manager) - await transition_callback(args, self) - else: - # New style: (args, result, flow_manager) - await transition_callback(args, result, self) - # Reset counter after transition completes - self._pending_function_calls = 0 - logger.debug("Reset pending function calls counter") - else: - logger.debug( - f"Skipping transition, {self._pending_function_calls} calls still pending" - ) - except Exception as e: - logger.error(f"Error in transition: {str(e)}") - self._pending_function_calls = 0 - await result_callback( - {"status": "error", "error": str(e)}, - properties=None, # Clear properties to prevent further callbacks - ) - raise # Re-raise to prevent further processing - - async def on_context_updated_node() -> None: - """Handle context updates for node functions without transitions.""" - decrease_pending_function_calls() - async def transition_func(params: FunctionCallParams) -> None: """Inner function that handles the actual tool invocation.""" try: - # Track pending function call - self._pending_function_calls += 1 - logger.debug( - f"Function call pending: {name} (total: {self._pending_function_calls})" - ) + logger.debug(f"Function called: {name}") # Execute handler if present is_transition_only_function = False @@ -430,42 +403,42 @@ async def transition_func(params: FunctionCallParams) -> None: result = acknowledged_result next_node = None is_transition_only_function = True + logger.debug( f"{'Transition-only function called for' if is_transition_only_function else 'Function handler completed for'} {name}" ) - # For edge functions, prevent LLM completion until transition (run_llm=False) - # For node functions, allow immediate completion (run_llm=True) + # Determine if this is an edge function has_explicit_transition = bool(transition_to) or bool(transition_callback) + is_edge_function = bool(next_node) or has_explicit_transition - async def on_context_updated() -> None: - if next_node: - await on_context_updated_edge( - next_node=next_node, - args=None, - result=None, - result_callback=params.result_callback, - ) - elif has_explicit_transition: - await on_context_updated_edge( - next_node=None, - args=params.arguments, - result=result, - result_callback=params.result_callback, - ) - else: - await on_context_updated_node() + if is_edge_function: + # Store transition info for coordinated execution + transition_info = { + "next_node": next_node, + "transition_to": transition_to, + "transition_callback": transition_callback, + "function_name": name, + "arguments": params.arguments, + "result": result, + } + self._pending_transition = transition_info + + properties = FunctionCallResultProperties( + run_llm=False, # Don't run LLM until transition completes + on_context_updated=self._check_and_execute_transition, + ) + else: + # Node function - run LLM immediately + properties = FunctionCallResultProperties( + run_llm=True, + on_context_updated=None, + ) - is_edge_function = bool(next_node) or has_explicit_transition - properties = FunctionCallResultProperties( - run_llm=not is_edge_function, - on_context_updated=on_context_updated, - ) await params.result_callback(result, properties=properties) except Exception as e: logger.error(f"Error in transition function {name}: {str(e)}") - self._pending_function_calls = 0 error_result = {"status": "error", "error": str(e)} await params.result_callback(error_result) @@ -605,6 +578,8 @@ async def _set_node(self, node_id: str, node_config: NodeConfig) -> None: raise FlowTransitionError(f"{self.__class__.__name__} must be initialized first") try: + self._pending_transition = None + self._validate_node_config(node_id, node_config) logger.debug(f"Setting node: {node_id}")