From 4b0e2d4b73dae3292a7530e4a40f07a9c2adf0ad Mon Sep 17 00:00:00 2001
From: "David S. Batista" 
Date: Mon, 14 Jul 2025 17:50:23 +0200
Subject: [PATCH 01/21] wip: fixing tests

---
 haystack/components/agents/__init__.py | 2 +-
 haystack/components/agents/agent.py | 290 ++++++++++-
 haystack/core/errors.py | 29 +-
 haystack/core/pipeline/base.py | 9 +-
 haystack/core/pipeline/breakpoint.py | 413 +++++++++++++++
 .../core/pipeline/delete_base_experimental.py | 60 +++
 .../pipeline/delete_pipeline_experimental.py | 393 +++++++++++++++
 haystack/core/pipeline/pipeline.py | 374 +++++++++++++-
 haystack/dataclasses/breakpoints.py | 96 ++++
 pyproject.toml | 1 +
 .../test_agent_breakpoints_inside_pipeline.py | 410 +++++++++++++++
 .../test_agent_breakpoints_isolation_async.py | 197 ++++++++
 .../test_agent_breakpoints_isolation_sync.py | 121 +++++
 .../agents/test_agent_breakpoints_utils.py | 120 +++++
 .../agents/test_state_class_experimental.py | 477 ++++++++++++++++++
 test/conftest.py | 32 ++
 test/core/pipeline/test_breakpoint.py | 128 +++++
 ...test_pipeline_breakpoints_answer_joiner.py | 124 +++++
 ...test_pipeline_breakpoints_branch_joiner.py | 120 +++++
 .../test_pipeline_breakpoints_list_joiner.py | 138 +++++
 .../test_pipeline_breakpoints_loops.py | 236 +++++++++
 .../test_pipeline_breakpoints_rag_hybrid.py | 292 +++++++++++
 ...test_pipeline_breakpoints_string_joiner.py | 65 +++
 .../pipeline/test_pipeline_experimental.py | 110 ++++
 24 files changed, 4206 insertions(+), 31 deletions(-)
 create mode 100644 haystack/core/pipeline/breakpoint.py
 create mode 100644 haystack/core/pipeline/delete_base_experimental.py
 create mode 100644 haystack/core/pipeline/delete_pipeline_experimental.py
 create mode 100644 haystack/dataclasses/breakpoints.py
 create mode 100644 test/components/agents/test_agent_breakpoints_inside_pipeline.py
 create mode 100644 test/components/agents/test_agent_breakpoints_isolation_async.py
 create mode 100644 test/components/agents/test_agent_breakpoints_isolation_sync.py
 create mode 100644 test/components/agents/test_agent_breakpoints_utils.py
 create mode 100644 test/components/agents/test_state_class_experimental.py
 create mode 100644 test/core/pipeline/test_breakpoint.py
 create mode 100644 test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py
 create mode 100644 test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py
 create mode 100644 test/core/pipeline/test_pipeline_breakpoints_list_joiner.py
 create mode 100644 test/core/pipeline/test_pipeline_breakpoints_loops.py
 create mode 100644 test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py
 create mode 100644 test/core/pipeline/test_pipeline_breakpoints_string_joiner.py
 create mode 100644 test/core/pipeline/test_pipeline_experimental.py

diff --git a/haystack/components/agents/__init__.py b/haystack/components/agents/__init__.py
index d331918f68..f94e305a6e 100644
--- a/haystack/components/agents/__init__.py
+++ b/haystack/components/agents/__init__.py
@@ -10,7 +10,7 @@
 _import_structure = {"agent": ["Agent"], "state": ["State"]}
 
 if TYPE_CHECKING:
-    from .agent import Agent as Agent
+    from .original_agent import Agent as Agent
     from .state import State as State
 
 else:
diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py
index b9d618b3ee..ceeab4760b 100644
--- a/haystack/components/agents/agent.py
+++ b/haystack/components/agents/agent.py
@@ -3,16 +3,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import inspect
+from copy import deepcopy
+from pathlib import Path
 from typing import Any, Dict, List,
Optional, Union -from haystack import component, default_from_dict, default_to_dict, logging, tracing +from haystack import logging, tracing from haystack.components.generators.chat.types import ChatGenerator from haystack.components.tools import ToolInvoker +from haystack.core.component.component import component +from haystack.core.errors import BreakpointException from haystack.core.pipeline.async_pipeline import AsyncPipeline +from haystack.core.pipeline.breakpoint import _save_state from haystack.core.pipeline.pipeline import Pipeline from haystack.core.pipeline.utils import _deepcopy_with_exceptions -from haystack.core.serialization import component_to_dict +from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -67,9 +73,8 @@ def __init__( exit_conditions: Optional[List[str]] = None, state_schema: Optional[Dict[str, Any]] = None, max_agent_steps: int = 100, - streaming_callback: Optional[StreamingCallbackT] = None, raise_on_tool_invocation_failure: bool = False, - tool_invoker_kwargs: Optional[Dict[str, Any]] = None, + streaming_callback: Optional[StreamingCallbackT] = None, ) -> None: """ Initialize the agent component. @@ -83,11 +88,10 @@ def __init__( :param state_schema: The schema for the runtime state used by the tools. :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. If the agent exceeds this number of steps, it will stop and return the current state. - :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. - The same callback can be configured to emit tool results when a tool is called. :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? If set to False, the exception will be turned into a chat message and passed to the LLM. - :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker. + :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. + The same callback can be configured to emit tool results when a tool is called. :raises TypeError: If the chat_generator does not support tools parameter in its run method. :raises ValueError: If the exit_conditions are not valid. """ @@ -137,15 +141,9 @@ def __init__( component.set_input_type(self, name=param, type=config["type"], default=None) component.set_output_types(self, **output_types) - self.tool_invoker_kwargs = tool_invoker_kwargs self._tool_invoker = None if self.tools: - resolved_tool_invoker_kwargs = { - "tools": self.tools, - "raise_on_failure": self.raise_on_tool_invocation_failure, - **(tool_invoker_kwargs or {}), - } - self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs) + self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure) else: logger.warning( "No tools provided to the Agent. 
The Agent will behave like a ChatGenerator and only return text " @@ -153,6 +151,7 @@ def __init__( ) self._is_warmed_up = False + self._agent_name: Optional[str] = None def warm_up(self) -> None: """ @@ -183,9 +182,8 @@ def to_dict(self) -> Dict[str, Any]: # We serialize the original state schema, not the resolved one to reflect the original user input state_schema=_schema_to_dict(self._state_schema), max_agent_steps=self.max_agent_steps, - streaming_callback=streaming_callback, raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, - tool_invoker_kwargs=self.tool_invoker_kwargs, + streaming_callback=streaming_callback, ) @classmethod @@ -229,8 +227,142 @@ def _create_agent_span(self) -> Any: }, ) - def run( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + def _validate_tool_breakpoint_is_valid(self, agent_breakpoint: AgentBreakpoint) -> None: + """ + Validates the AgentBreakpoint passed to the agent. + + Validates that all tool names in ToolBreakpoints correspond to tools available in the agent. + + :param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components. + :raises ValueError: If any tool name in ToolBreakpoints is not available in the agent's tools. + """ + + available_tool_names = {tool.name for tool in self.tools} + tool_breakpoint = agent_breakpoint.break_point + if tool_breakpoint.tool_name is not None and tool_breakpoint.tool_name not in available_tool_names: # type: ignore # was checked outside function + raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools") # type: ignore # was checked outside function + + def _check_chat_generator_breakpoint( # pylint: disable=too-many-positional-arguments + self, + agent_breakpoint: Optional[AgentBreakpoint], + component_visits: Dict[str, int], + messages: List[ChatMessage], + generator_inputs: Dict[str, Any], + debug_path: Optional[Union[str, Path]], + kwargs: Dict[str, Any], + state: State, + ) -> None: + """ + Check for breakpoint before calling the ChatGenerator. 
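+
+        If the breakpoint's visit count matches the chat generator's current visit, the inputs are
+        snapshotted via `_save_state` and a `BreakpointException` is raised so the run can be resumed later.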
+
+        :param agent_breakpoint: AgentBreakpoint object containing breakpoints
+        :param component_visits: Dictionary tracking component visit counts
+        :param messages: Current messages to process
+        :param generator_inputs: Inputs for the chat generator
+        :param debug_path: Path for saving debug state
+        :param kwargs: Additional keyword arguments
+        :param state: Current agent state
+        :raises BreakpointException: If a breakpoint is triggered
+        """
+
+        if agent_breakpoint and isinstance(agent_breakpoint.break_point, Breakpoint):
+            break_point = agent_breakpoint.break_point
+            if component_visits[break_point.component_name] == break_point.visit_count:
+                state_inputs = deepcopy({"messages": messages, **generator_inputs})
+                _save_state(
+                    inputs=state_inputs,
+                    component_name=break_point.component_name,
+                    component_visits=component_visits,  # these are the component visits of the agent components
+                    debug_path=debug_path,
+                    original_input_data={"messages": messages, **kwargs},
+                    ordered_component_names=["chat_generator", "tool_invoker"],
+                    agent_name=self._agent_name,
+                    main_pipeline_state=state.data.get("main_pipeline_state", {}),
+                )
+                msg = (
+                    f"Breaking at {break_point.component_name} visit count "
+                    f"{component_visits[break_point.component_name]}"
+                )
+                logger.info(msg)
+                raise BreakpointException(
+                    message=msg, component=break_point.component_name, state=state_inputs, results=state.data
+                )
+
+    def _check_tool_invoker_breakpoint(  # pylint: disable=too-many-positional-arguments
+        self,
+        agent_breakpoint: Optional[AgentBreakpoint],
+        component_visits: Dict[str, int],
+        llm_messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT],
+        debug_path: Optional[Union[str, Path]],
+        messages: List[ChatMessage],
+        kwargs: Dict[str, Any],
+        state: State,
+    ) -> None:
+        """
+        Check for breakpoint before calling the ToolInvoker.
+
+        :param agent_breakpoint: AgentBreakpoint object containing breakpoints
+        :param component_visits: Dictionary tracking component visit counts
+        :param llm_messages: Messages from the LLM
+        :param streaming_callback: Streaming callback function
+        :param debug_path: Path for saving debug state
+        :param messages: Original messages
+        :param kwargs: Additional keyword arguments
+        :param state: Current agent state
+        :raises BreakpointException: If a breakpoint is triggered
+        """
+
+        if agent_breakpoint and isinstance(agent_breakpoint.break_point, ToolBreakpoint):
+            tool_breakpoint = agent_breakpoint.break_point
+            # Check if the visit count matches
+            if component_visits[tool_breakpoint.component_name] == tool_breakpoint.visit_count:
+                # Check if we should break for this specific tool or all tools
+                should_break = False
+                if tool_breakpoint.tool_name is None:
+                    # Break for any tool call
+                    should_break = any(msg.tool_call for msg in llm_messages)
+                else:
+                    # Break only for the specific tool
+                    should_break = any(
+                        msg.tool_call and msg.tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages
+                    )
+
+                if should_break:
+                    state_inputs = deepcopy(
+                        {"messages": llm_messages, "state": state, "streaming_callback": streaming_callback}
+                    )
+                    _save_state(
+                        inputs=state_inputs,
+                        component_name=tool_breakpoint.component_name,
+                        component_visits=component_visits,
+                        debug_path=debug_path,
+                        original_input_data={"messages": messages, **kwargs},
+                        ordered_component_names=["chat_generator", "tool_invoker"],
+                        agent_name=self._agent_name,
+                        main_pipeline_state=state.data.get("main_pipeline_state", {}),
+                    )
+                    msg = (
+                        f"Breaking at {tool_breakpoint.component_name} visit count "
+                        f"{component_visits[tool_breakpoint.component_name]}"
+                    )
+                    if tool_breakpoint.tool_name:
+                        msg += f" for tool {tool_breakpoint.tool_name}"
+                    logger.info(msg)
+
+                    raise BreakpointException(
+                        message=msg, component=tool_breakpoint.component_name, state=state_inputs, results=state.data
+                    )
+
+    def run(  # noqa: PLR0915
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        *,
+        break_point: Optional[AgentBreakpoint] = None,
+        resume_state: Optional[Dict[str, Any]] = None,
+        debug_path: Optional[Union[str, Path]] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         """
         Process messages and execute tools until an exit condition is met.
@@ -239,6 +371,10 @@ def run(
             If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
            The same callback can be configured to emit tool results when a tool is called.
+        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
+            for "tool_invoker".
+        :param resume_state: A dictionary containing the state of a previously saved agent execution.
+        :param debug_path: Path to the directory where the agent state should be saved.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :returns:
@@ -247,10 +383,43 @@ def run(
             - "last_message": The last message exchanged during the agent's run.
             - Any additional keys defined in the `state_schema`.
         :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
+        :raises BreakpointException: If an agent breakpoint is triggered.
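+
+        Usage sketch (illustrative only; the `AgentBreakpoint` and `Breakpoint` field names are assumed
+        from how they are read in this module, and the state file name follows `_save_state_to_file`):
+
+        ```python
+        from haystack.core.errors import BreakpointException
+        from haystack.core.pipeline.breakpoint import load_state
+        from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint
+
+        # stop right before the first call to the chat generator
+        break_point = AgentBreakpoint(
+            agent_name="my_agent", break_point=Breakpoint(component_name="chat_generator", visit_count=0)
+        )
+        try:
+            agent.run(messages=messages, break_point=break_point, debug_path="debug/")
+        except BreakpointException:
+            # later: load the saved state and continue the run from where it stopped
+            resume_state = load_state("debug/my_agent_chat_generator_2025_07_14_17_50_23.json")
+            result = agent.run(messages=messages, resume_state=resume_state)
+        ```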
+
         """
         if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
             raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")
 
+        if break_point and resume_state:
+            msg = (
+                "break_point and resume_state cannot be provided at the same time. The agent run will be aborted."
+            )
+            raise ValueError(msg)
+
+        self._agent_name = self.__component_name__ if hasattr(self, "__component_name__") else "isolated_agent"
+
+        # validate breakpoints
+        if break_point and isinstance(break_point.break_point, ToolBreakpoint):
+            self._validate_tool_breakpoint_is_valid(break_point)
+
+        # resume state if provided
+        if resume_state:
+            component_visits = resume_state.get("pipeline_state", {}).get("component_visits", {})
+            state_data = resume_state.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {})
+            state = State(schema=self.state_schema, data=state_data)
+
+            # deserialize messages from pipeline state
+            raw_messages = resume_state.get("pipeline_state", {}).get("inputs", {}).get("messages", messages)
+
+            # convert raw message dictionaries to ChatMessage objects and populate the state
+            messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages]
+            state.set("messages", messages)
+
+        else:
+            # initialize new state if not resuming
+            state = State(schema=self.state_schema, data=kwargs)
+            state.set("messages", messages)
+            component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
+
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
 
@@ -262,7 +431,5 @@ def run(
 
-        state = State(schema=self.state_schema, data=kwargs)
         state.set("messages", messages)
-        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
 
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
@@ -274,7 +442,16 @@ def run(
             _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
         )
         counter = 0
+
+        if break_point and self._agent_name is None:
+            raise ValueError("When using breakpoints, the agent_name must be provided to save the state correctly.")
+
         while counter < self.max_agent_steps:
+            # check for breakpoint before ChatGenerator
+            self._check_chat_generator_breakpoint(
+                break_point, component_visits, messages, generator_inputs, debug_path, kwargs, state
+            )
+
             # 1. Call the ChatGenerator
             result = Pipeline._run_component(
                 component_name="chat_generator",
@@ -291,6 +468,11 @@ def run(
                 counter += 1
                 break
 
+            # check for breakpoint before ToolInvoker
+            self._check_tool_invoker_breakpoint(
+                break_point, component_visits, llm_messages, streaming_callback, debug_path, messages, kwargs, state
+            )
+
             # 3.
Call the ToolInvoker
             # We only send the messages from the LLM to the tool invoker
             tool_invoker_result = Pipeline._run_component(
@@ -327,8 +509,15 @@ def run(
         result.update({"last_message": all_messages[-1]})
         return result
 
-    async def run_async(
-        self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
+    async def run_async(  # noqa: PLR0915
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        *,
+        break_point: Optional[AgentBreakpoint] = None,
+        resume_state: Optional[Dict[str, Any]] = None,
+        debug_path: Optional[Union[str, Path]] = None,
+        **kwargs: Any,
     ) -> Dict[str, Any]:
         """
         Asynchronously process messages and execute tools until the exit condition is met.
@@ -339,8 +528,11 @@ async def run_async(
 
         :param messages: List of chat messages to process
         :param streaming_callback: An asynchronous callback that will be invoked when a response
-            is streamed from the LLM. The same callback can be configured to emit tool results
-            when a tool is called.
+            is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called.
+        :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
+            for "tool_invoker".
+        :param resume_state: A dictionary containing the state of a previously saved agent execution.
+        :param debug_path: Path to the directory where the agent state should be saved.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :returns:
@@ -349,10 +541,46 @@ async def run_async(
             - "last_message": The last message exchanged during the agent's run.
             - Any additional keys defined in the `state_schema`.
         :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`.
+        :raises BreakpointException: If an agent breakpoint is triggered.
+
         """
         if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
             raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
 
+        if break_point and resume_state:
+            msg = (
+                "break_point and resume_state cannot be provided at the same time. The agent run will be aborted."
+            )
+            raise ValueError(msg)
+
+        self._agent_name = self.__component_name__ if hasattr(self, "__component_name__") else "isolated_agent"
+
+        # validate breakpoints
+        if break_point and isinstance(break_point.break_point, ToolBreakpoint):
+            self._validate_tool_breakpoint_is_valid(break_point)
+
+        # Handle resume state if provided
+        if resume_state:
+            # Extract component visits from pipeline state
+            component_visits = resume_state.get("pipeline_state", {}).get("component_visits", {})
+            # Initialize with default values if not present in resume state
+            component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) | component_visits
+
+            # Extract state data from pipeline state
+            state_data = resume_state.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {})
+            state = State(schema=self.state_schema, data=state_data)
+
+            # Extract and deserialize messages from pipeline state
+            raw_messages = resume_state.get("pipeline_state", {}).get("inputs", {}).get("messages", messages)
+            # Convert raw message dictionaries to ChatMessage objects
+            messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages]
+            state.set("messages", messages)
+        else:
+            # Initialize new state if not resuming
+            state = State(schema=self.state_schema, data=kwargs)
+            state.set("messages", messages)
+            component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
+
         if self.system_prompt is not None:
             messages = [ChatMessage.from_system(self.system_prompt)] + messages
 
@@ -364,7 +592,5 @@ async def run_async(
 
-        state = State(schema=self.state_schema, data=kwargs)
         state.set("messages", messages)
-        component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
 
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
@@ -376,7 +603,16 @@ async def run_async(
             _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
         )
         counter = 0
+
+        if break_point and self._agent_name is None:
+            raise ValueError("When using breakpoints, the agent_name must be provided to save the state correctly.")
+
         while counter < self.max_agent_steps:
+            # Check for breakpoint before ChatGenerator
+            self._check_chat_generator_breakpoint(
+                break_point, component_visits, messages, generator_inputs, debug_path, kwargs, state
+            )
+
             # 1. Call the ChatGenerator
             result = await AsyncPipeline._run_component_async(
                 component_name="chat_generator",
@@ -394,9 +630,13 @@ async def run_async(
                 counter += 1
                 break
 
+            # Check for breakpoint before ToolInvoker
+            self._check_tool_invoker_breakpoint(
+                break_point, component_visits, llm_messages, streaming_callback, debug_path, messages, kwargs, state
+            )
+
             # 3. Call the ToolInvoker
             # We only send the messages from the LLM to the tool invoker
-            # Check if the ToolInvoker supports async execution. Currently, it doesn't.
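+            # The call below delegates to AsyncPipeline's async component runner.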
tool_invoker_result = await AsyncPipeline._run_component_async( component_name="tool_invoker", component={"instance": self._tool_invoker}, diff --git a/haystack/core/errors.py b/haystack/core/errors.py index 04c4ccc864..9de137da1f 100644 --- a/haystack/core/errors.py +++ b/haystack/core/errors.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional, Type +from typing import Any, Dict, Optional, Type class PipelineError(Exception): @@ -89,3 +89,30 @@ class DeserializationError(Exception): class SerializationError(Exception): pass + + +class BreakpointException(Exception): + """ + Exception raised when a pipeline breakpoint is triggered. + """ + + def __init__( + self, + message: str, + component: Optional[str] = None, + state: Optional[Dict[str, Any]] = None, + results: Optional[Dict[str, Any]] = None, + ): + super().__init__(message) + self.component = component + self.state = state + self.results = results + + +class PipelineInvalidResumeStateError(Exception): + """ + Exception raised when a pipeline is resumed from an invalid state. + """ + + def __init__(self, message: str): + super().__init__(message) diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 138e666409..f77e458adb 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -1064,7 +1064,9 @@ def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Di return inputs @staticmethod - def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]: + def _consume_component_inputs( + component_name: str, component: Dict, inputs: Dict, is_resume: bool = False + ) -> Dict[str, Any]: """ Extracts the inputs needed to run for the component and removes them from the global inputs state. @@ -1079,6 +1081,11 @@ def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict for socket_name, socket in component["input_sockets"].items(): socket_inputs = component_inputs.get(socket_name, []) socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED] + + # if we are resuming a component, the inputs are already consumed, so we just return the first input + if is_resume: + consumed_inputs[socket_name] = socket_inputs[0] + continue if socket_inputs: if not socket.is_variadic: # We only care about the first input provided to the socket. diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py new file mode 100644 index 0000000000..bc1a4adf5b --- /dev/null +++ b/haystack/core/pipeline/breakpoint.py @@ -0,0 +1,413 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-many-return-statements, too-many-positional-arguments + +import json +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from networkx import MultiDiGraph + +from haystack import logging +from haystack.core.errors import BreakpointException, PipelineInvalidResumeStateError +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint +from haystack.utils.base_serialization import _serialize_value_with_schema + +logger = logging.getLogger(__name__) + + +def _validate_break_point(break_point: Union[Breakpoint, AgentBreakpoint], graph: MultiDiGraph) -> None: + """ + Validates the breakpoints passed to the pipeline. 
+
+    Makes sure the breakpoint references a valid component registered in the pipeline.
+
+    :param break_point: a breakpoint to validate, can be Breakpoint or AgentBreakpoint
+    """
+
+    # all Breakpoints must refer to a valid component in the pipeline
+    if isinstance(break_point, Breakpoint) and break_point.component_name not in graph.nodes:
+        raise ValueError(f"pipeline_breakpoint {break_point} is not a registered component in the pipeline")
+
+    if isinstance(break_point, AgentBreakpoint):
+        breakpoint_agent_component = graph.nodes.get(break_point.agent_name)
+        if not breakpoint_agent_component:
+            raise ValueError(f"pipeline_breakpoint {break_point} is not a registered Agent component in the pipeline")
+
+        if isinstance(break_point.break_point, ToolBreakpoint):
+            instance = breakpoint_agent_component["instance"]
+            for tool in instance.tools:
+                if break_point.break_point.tool_name == tool.name:
+                    break
+            else:
+                raise ValueError(
+                    f"pipeline_breakpoint {break_point.break_point} is not a registered tool in the Agent component"
+                )
+
+
+def _validate_components_against_pipeline(resume_state: Dict[str, Any], graph: MultiDiGraph) -> None:
+    """
+    Validates that the resume_state contains valid configuration for the current pipeline.
+
+    Raises a PipelineInvalidResumeStateError if any component in resume_state is not part of the target pipeline.
+
+    :param resume_state: The saved state to validate.
+    """
+
+    pipeline_state = resume_state["pipeline_state"]
+    valid_components = set(graph.nodes.keys())
+
+    # Check that the components in 'ordered_component_names' are part of the current pipeline
+    invalid_ordered_components = set(pipeline_state["ordered_component_names"]) - valid_components
+    if invalid_ordered_components:
+        raise PipelineInvalidResumeStateError(
+            f"Invalid resume state: components {invalid_ordered_components} in 'ordered_component_names' "
+            f"are not part of the current pipeline."
+        )
+
+    # Check that the components in 'input_data' are part of the current pipeline
+    serialized_input_data = resume_state["input_data"]["serialized_data"]
+    invalid_input_data = set(serialized_input_data.keys()) - valid_components
+    if invalid_input_data:
+        raise PipelineInvalidResumeStateError(
+            f"Invalid resume state: components {invalid_input_data} in 'input_data' "
+            f"are not part of the current pipeline."
+        )
+
+    # Validate 'component_visits'
+    invalid_component_visits = set(pipeline_state["component_visits"].keys()) - valid_components
+    if invalid_component_visits:
+        raise PipelineInvalidResumeStateError(
+            f"Invalid resume state: components {invalid_component_visits} in 'component_visits' "
+            f"are not part of the current pipeline."
+        )
+
+    logger.info(
+        f"Resuming pipeline from component: {resume_state['pipeline_breakpoint']['component']} "
+        f"(visit {resume_state['pipeline_breakpoint']['visits']})"
+    )
+
+
+def _validate_resume_state(resume_state: Dict[str, Any]) -> None:
+    """
+    Validates the loaded pipeline resume_state.
+
+    Ensures that the resume_state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state".
+
+    Raises:
+        ValueError: If required keys are missing or the component sets are inconsistent.
+ """ + + # top-level state has all required keys + required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} + missing_top = required_top_keys - resume_state.keys() + if missing_top: + raise ValueError(f"Invalid state file: missing required keys {missing_top}") + + # pipeline_state has the necessary keys + pipeline_state = resume_state["pipeline_state"] + + required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} + missing_pipeline = required_pipeline_keys - pipeline_state.keys() + if missing_pipeline: + raise ValueError(f"Invalid pipeline_state: missing required keys {missing_pipeline}") + + # component_visits and ordered_component_names must be consistent + components_in_state = set(pipeline_state["component_visits"].keys()) + components_in_order = set(pipeline_state["ordered_component_names"]) + + if components_in_state != components_in_order: + raise ValueError( + f"Inconsistent state: components in pipeline_state['component_visits'] {components_in_state} " + f"do not match components in ordered_component_names {components_in_order}" + ) + + logger.info("Passed resume state validated successfully.") + + +def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: + """ + Load a saved pipeline state. + + :param file_path: Path to the resume_state file. + :returns: + Dict containing the loaded resume_state. + """ + + file_path = Path(file_path) + + try: + with open(file_path, "r", encoding="utf-8") as f: + state = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"File not found: {file_path}") + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos) + except IOError as e: + raise IOError(f"Error reading {file_path}: {str(e)}") + + try: + _validate_resume_state(resume_state=state) + except ValueError as e: + raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") + + logger.info(f"Successfully loaded pipeline state from: {file_path}") + return state + + +def _process_main_pipeline_state(main_pipeline_state: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Process and serialize main pipeline state for agent breakpoints. + + :param main_pipeline_state: Dictionary containing main pipeline state with keys: "component_visits", + "ordered_component_names", "original_input_data", and "inputs". + :returns: Processed main pipeline state or None if not available or invalid. + """ + if not main_pipeline_state: + return None + + original_input_data = main_pipeline_state.get("original_input_data") + inputs = main_pipeline_state.get("inputs") + + if not (original_input_data and inputs): + return None + + return { + "component_visits": main_pipeline_state.get("component_visits"), + "ordered_component_names": main_pipeline_state.get("ordered_component_names"), + "original_input_data": _serialize_value_with_schema(_transform_json_structure(original_input_data)), + "inputs": _serialize_value_with_schema(_transform_json_structure(inputs)), + } + + +def _save_state_to_file( + state: Dict[str, Any], + debug_path: Union[str, Path], + dt: datetime, + is_agent: bool, + agent_name: Optional[str], + component_name: str, +) -> None: + """ + Save state dictionary to a JSON file. + + :param state: The state dictionary to save. + :param debug_path: The path where to save the file. + :param dt: The datetime object for timestamping. + :param is_agent: Whether this is an agent pipeline. + :param agent_name: Name of the agent (if applicable). 
+ :param component_name: Name of the component that triggered the breakpoint. + :raises: + ValueError: If the debug_path is not a string or a Path object. + Exception: If saving the JSON state fails. + """ + debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path + if not isinstance(debug_path, Path): + raise ValueError("Debug path must be a string or a Path object.") + + debug_path.mkdir(exist_ok=True) + + # Generate filename + if is_agent: + file_name = f"{agent_name}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + else: + file_name = f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + + try: + with open(debug_path / file_name, "w") as f_out: + json.dump(state, f_out, indent=2) + logger.info(f"Pipeline state saved at: {file_name}") + except Exception as e: + logger.error(f"Failed to save pipeline state: {str(e)}") + raise + + +def _save_state( + inputs: Dict[str, Any], + component_name: str, + component_visits: Dict[str, int], + debug_path: Optional[Union[str, Path]] = None, + original_input_data: Optional[Dict[str, Any]] = None, + ordered_component_names: Optional[List[str]] = None, + agent_name: Optional[str] = None, + main_pipeline_state: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Save the pipeline state to a file. + + :param inputs: The current pipeline state inputs. + :param component_name: The name of the component that triggered the breakpoint. + :param component_visits: The visit count of the component that triggered the breakpoint. + :param debug_path: The path to save the state to. + :param original_input_data: The original input data. + :param ordered_component_names: The ordered component names. + :param main_pipeline_state: Dictionary containing main pipeline state with keys: "component_visits", + "ordered_component_names", "original_input_data", and "inputs". + + :returns: + The dictionary containing the state of the pipeline containing the following keys: + - input_data: The original input data passed to the pipeline. + - timestamp: The timestamp of the breakpoint. + - pipeline_breakpoint: The component name and visit count that triggered the breakpoint. + - pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys: + - inputs: The current state of inputs for pipeline components. + - component_visits: The visit count of the components when the breakpoint was triggered. + - ordered_component_names: The order of components in the pipeline. 
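+
+    A minimal sketch of the returned dictionary (values are illustrative):
+
+    ```python
+    {
+        "agent_name": None,
+        "main_pipeline_state": None,
+        "component_name": "retriever",
+        "input_data": {...},
+        "timestamp": "2025-07-14T17:50:23",
+        "pipeline_breakpoint": {"component": "retriever", "visits": 0},
+        "pipeline_state": {"inputs": {...}, "component_visits": {...}, "ordered_component_names": [...]},
+    }
+    ```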
+    """
+    dt = datetime.now()
+
+    # remove duplicated information
+    if original_input_data:
+        original_input_data.pop("main_pipeline_state", None)
+
+    transformed_original_input_data = _transform_json_structure(original_input_data)
+    transformed_inputs = _transform_json_structure(inputs)
+
+    state = {
+        # related to the main pipeline in which the agent is running - only used with AgentBreakpoint
+        "agent_name": agent_name if agent_name else None,
+        "main_pipeline_state": _process_main_pipeline_state(main_pipeline_state) if agent_name else None,
+        # breakpoint - information for the component that triggered the breakpoint, can also be an Agent
+        "component_name": component_name,
+        "input_data": _serialize_value_with_schema(transformed_original_input_data),  # original input data
+        "timestamp": dt.isoformat(),
+        "pipeline_breakpoint": {"component": component_name, "visits": component_visits[component_name]},
+        "pipeline_state": {
+            "inputs": _serialize_value_with_schema(transformed_inputs),  # current pipeline state inputs
+            "component_visits": component_visits,
+            "ordered_component_names": ordered_component_names,
+        },
+    }
+
+    if not debug_path:
+        return state
+
+    is_agent = agent_name is not None
+    _save_state_to_file(state, debug_path, dt, is_agent, agent_name, component_name)
+
+    return state
+
+
+def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any:
+    """
+    Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level.
+
+    For example:
+    "key": [{"sender": null, "value": "some value"}] -> "key": "some value"
+
+    :param data: The JSON structure to transform.
+    :returns: The transformed structure.
+    """
+    if isinstance(data, dict):
+        # If this dict has both 'sender' and 'value', return just the value
+        if "value" in data and "sender" in data:
+            return data["value"]
+        # Otherwise, recursively process each key-value pair
+        return {k: _transform_json_structure(v) for k, v in data.items()}
+
+    if isinstance(data, list):
+        # First, transform each item in the list.
+        transformed = [_transform_json_structure(item) for item in data]
+        # If the original list has exactly one element and that element was a dict
+        # with 'sender' and 'value', then unwrap the list.
+        if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]:
+            return transformed[0]
+        return transformed
+
+    # For other data types, just return the value as is.
+    return data
+
+
+def handle_agent_break_point(
+    break_point: AgentBreakpoint,
+    component_name: str,
+    component_inputs: Dict[str, Any],
+    inputs: Dict[str, Any],
+    component_visits: Dict[str, int],
+    ordered_component_names: list,
+    data: Dict[str, Any],
+    debug_path: Optional[Union[str, Path]],
+) -> Dict[str, Any]:
+    """
+    Handle agent-specific breakpoint logic.
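+
+    Rather than stopping in the outer pipeline, the breakpoint, debug path, and a snapshot of the main
+    pipeline state are injected into the Agent's inputs so that `Agent.run` can save a combined state.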
+ + :param break_point: The agent breakpoint to handle + :param component_name: Name of the current component + :param component_inputs: Inputs for the current component + :param inputs: Global pipeline inputs + :param component_visits: Component visit counts + :param ordered_component_names: Ordered list of component names + :param data: Original pipeline data + :param debug_path: Path for debug files + :return: Updated component inputs + """ + component_inputs["break_point"] = break_point + component_inputs["debug_path"] = debug_path + + # Store pipeline state for agent resume + state_inputs_serialised = deepcopy(inputs) + state_inputs_serialised[component_name] = deepcopy(component_inputs) + component_inputs["main_pipeline_state"] = { + "inputs": state_inputs_serialised, + "component_visits": component_visits, + "ordered_component_names": ordered_component_names, + "original_input_data": data, + } + + return component_inputs + + +def check_regular_break_point(break_point: Breakpoint, component_name: str, component_visits: Dict[str, int]) -> bool: + """ + Check if a regular breakpoint should be triggered. + + :param break_point: The breakpoint to check + :param component_name: Name of the current component + :param component_visits: Component visit counts + :return: True if breakpoint should be triggered + """ + return break_point.component_name == component_name and break_point.visit_count == component_visits[component_name] + + +def trigger_break_point( + component_name: str, + component_inputs: Dict[str, Any], + inputs: Dict[str, Any], + component_visits: Dict[str, int], + debug_path: Optional[Union[str, Path]], + data: Dict[str, Any], + ordered_component_names: list, + pipeline_outputs: Dict[str, Any], +) -> None: + """ + Trigger a breakpoint by saving state and raising exception. 
+
+    :param component_name: Name of the component where breakpoint is triggered
+    :param component_inputs: Inputs for the current component
+    :param inputs: Global pipeline inputs
+    :param component_visits: Component visit counts
+    :param debug_path: Path for debug files
+    :param data: Original pipeline data
+    :param ordered_component_names: Ordered list of component names
+    :param pipeline_outputs: Current pipeline outputs
+    :raises BreakpointException: When the breakpoint is triggered
+    """
+    state_inputs_serialised = deepcopy(inputs)
+    state_inputs_serialised[component_name] = deepcopy(component_inputs)
+    _save_state(
+        inputs=state_inputs_serialised,
+        component_name=str(component_name),
+        component_visits=component_visits,
+        debug_path=debug_path,
+        original_input_data=data,
+        ordered_component_names=ordered_component_names,
+    )
+
+    msg = f"Breaking at component {component_name} at visit count {component_visits[component_name]}"
+    raise BreakpointException(
+        message=msg, component=component_name, state=state_inputs_serialised, results=pipeline_outputs
+    )
diff --git a/haystack/core/pipeline/delete_base_experimental.py b/haystack/core/pipeline/delete_base_experimental.py
new file mode 100644
index 0000000000..94ea2f749e
--- /dev/null
+++ b/haystack/core/pipeline/delete_base_experimental.py
@@ -0,0 +1,60 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict
+
+from haystack.core.pipeline.base import PipelineBase as HaystackPipelineBase
+from haystack.core.pipeline.component_checks import _NO_OUTPUT_PRODUCED, is_socket_lazy_variadic
+
+
+class PipelineBase(HaystackPipelineBase):
+    @staticmethod
+    def _consume_component_inputs(
+        component_name: str, component: Dict, inputs: Dict, is_resume: bool = False
+    ) -> Dict[str, Any]:
+        """
+        Extracts the inputs needed to run for the component and removes them from the global inputs state.
+
+        :param component_name: The name of a component.
+        :param component: Component with component metadata.
+        :param inputs: Global inputs state.
+        :returns: The inputs for the component.
+        """
+        component_inputs = inputs.get(component_name, {})
+        consumed_inputs = {}
+        greedy_inputs_to_remove = set()
+        for socket_name, socket in component["input_sockets"].items():
+            socket_inputs = component_inputs.get(socket_name, [])
+            socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
+
+            # if we are resuming a component, the inputs are already consumed, so we just return the first input
+            if is_resume:
+                consumed_inputs[socket_name] = socket_inputs[0]
+                continue
+            if socket_inputs:
+                if not socket.is_variadic:
+                    # We only care about the first input provided to the socket.
+                    consumed_inputs[socket_name] = socket_inputs[0]
+                elif socket.is_greedy:
+                    # We need to keep track of greedy inputs because we always remove them, even if they come from
+                    # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run
+                    # indefinitely.
+                    greedy_inputs_to_remove.add(socket_name)
+                    consumed_inputs[socket_name] = [socket_inputs[0]]
+                elif is_socket_lazy_variadic(socket):
+                    # We use all inputs provided to the socket on a lazy variadic socket.
+                    consumed_inputs[socket_name] = socket_inputs
+
+        # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs).
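+        # Illustrative shape (assumed): inputs = {"comp": {"text": [{"sender": None, "value": "hello"}]}};
+        # entries with sender=None come from the caller and are kept so the component can be re-run.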
+ pruned_inputs = { + socket_name: [ + sock for sock in socket if sock["sender"] is None and not socket_name in greedy_inputs_to_remove + ] + for socket_name, socket in component_inputs.items() + } + pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0} + + inputs[component_name] = pruned_inputs + + return consumed_inputs diff --git a/haystack/core/pipeline/delete_pipeline_experimental.py b/haystack/core/pipeline/delete_pipeline_experimental.py new file mode 100644 index 0000000000..78d693b48d --- /dev/null +++ b/haystack/core/pipeline/delete_pipeline_experimental.py @@ -0,0 +1,393 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-many-return-statements, too-many-positional-arguments + + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Optional, Set, Union + +from haystack_experimental.core.errors import PipelineInvalidResumeStateError +from haystack_experimental.core.pipeline.base import PipelineBase + +from haystack import logging, tracing +from haystack.core.pipeline.base import ComponentPriority +from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline +from haystack.telemetry import pipeline_running +from haystack.utils import _deserialize_value_with_schema + +from ...components.agents import Agent +from ...dataclasses.breakpoints import AgentBreakpoint, Breakpoint +from .breakpoint import ( + _validate_break_point, + _validate_components_against_pipeline, + check_regular_break_point, + handle_agent_break_point, + trigger_break_point, +) + +logger = logging.getLogger(__name__) + + +# We inherit from both HaystackPipeline and PipelineBase to ensure that we have the +# necessary methods and properties from both classes. +class Pipeline(HaystackPipeline, PipelineBase): + """ + Synchronous version of the orchestration engine. + + Orchestrates component execution according to the execution graph, one after the other. + """ + + def _handle_resume_state(self, resume_state: Dict[str, Any]) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: + """ + Handle resume state initialization. + + :param resume_state: The resume state to handle + :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) + """ + if resume_state.get("agent_name"): + return self._handle_agent_resume_state(resume_state) + else: + return self._handle_regular_resume_state(resume_state) + + def _handle_agent_resume_state( + self, resume_state: Dict[str, Any] + ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: + """ + Handle agent-specific resume state. 
+ + :param resume_state: The resume state to handle + :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) + """ + agent_name = resume_state["agent_name"] + for name, component in self.graph.nodes.items(): + if component["instance"].__class__.__name__ == "Agent" and name == agent_name: + main_pipeline_state = resume_state.get("main_pipeline_state", {}) + component_visits = main_pipeline_state.get("component_visits", {}) + ordered_component_names = main_pipeline_state.get("ordered_component_names", []) + data = _deserialize_value_with_schema(main_pipeline_state.get("inputs", {})) + return component_visits, data, True, ordered_component_names + + # Fallback to regular resume if agent not found + return self._handle_regular_resume_state(resume_state) + + def _handle_regular_resume_state( + self, resume_state: Dict[str, Any] + ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: + """ + Handle regular component resume state. + + :param resume_state: The resume state to handle + :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) + """ + component_visits, data, resume_state, ordered_component_names = self.inject_resume_state_into_graph( + resume_state=resume_state + ) + data = _deserialize_value_with_schema(resume_state["pipeline_state"]["inputs"]) + return component_visits, data, False, ordered_component_names + + def run( # noqa: PLR0915, PLR0912 + self, + data: Dict[str, Any], + include_outputs_from: Optional[Set[str]] = None, + break_point: Optional[Union[Breakpoint, AgentBreakpoint]] = None, + resume_state: Optional[Dict[str, Any]] = None, + debug_path: Optional[Union[str, Path]] = None, + ) -> Dict[str, Any]: + """ + Runs the Pipeline with given input data. + + Usage: + ```python + from haystack import Pipeline, Document + from haystack.utils import Secret + from haystack.document_stores.in_memory import InMemoryDocumentStore + from haystack.components.retrievers.in_memory import InMemoryBM25Retriever + from haystack.components.generators import OpenAIGenerator + from haystack.components.builders.answer_builder import AnswerBuilder + from haystack.components.builders.prompt_builder import PromptBuilder + + # Write documents to InMemoryDocumentStore + document_store = InMemoryDocumentStore() + document_store.write_documents([ + Document(content="My name is Jean and I live in Paris."), + Document(content="My name is Mark and I live in Berlin."), + Document(content="My name is Giorgio and I live in Rome.") + ]) + + prompt_template = \"\"\" + Given these documents, answer the question. + Documents: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + Question: {{question}} + Answer: + \"\"\" + + retriever = InMemoryBM25Retriever(document_store=document_store) + prompt_builder = PromptBuilder(template=prompt_template) + llm = OpenAIGenerator(api_key=Secret.from_token(api_key)) + + rag_pipeline = Pipeline() + rag_pipeline.add_component("retriever", retriever) + rag_pipeline.add_component("prompt_builder", prompt_builder) + rag_pipeline.add_component("llm", llm) + rag_pipeline.connect("retriever", "prompt_builder.documents") + rag_pipeline.connect("prompt_builder", "llm") + + # Ask a question + question = "Who lives in Paris?" + results = rag_pipeline.run( + { + "retriever": {"query": question}, + "prompt_builder": {"question": question}, + } + ) + + print(results["llm"]["replies"]) + # Jean lives in Paris + ``` + + :param data: + A dictionary of inputs for the pipeline's components. 
Each key is a component name + and its value is a dictionary of that component's input parameters: + ``` + data = { + "comp1": {"input1": 1, "input2": 2}, + } + ``` + For convenience, this format is also supported when input names are unique: + ``` + data = { + "input1": 1, "input2": 2, + } + ``` + :param include_outputs_from: + Set of component names whose individual outputs are to be + included in the pipeline's output. For components that are + invoked multiple times (in a loop), only the last-produced + output is included. + + :param break_point: + A set of breakpoints that can be used to debug the pipeline execution. + + :param resume_state: + A dictionary containing the state of a previously saved pipeline execution. + + :param debug_path: + Path to the directory where the pipeline state should be saved. + + :returns: + A dictionary where each entry corresponds to a component name + and its output. If `include_outputs_from` is `None`, this dictionary + will only contain the outputs of leaf components, i.e., components + without outgoing connections. + + :raises ValueError: + If invalid inputs are provided to the pipeline. + :raises PipelineRuntimeError: + If the Pipeline contains cycles with unsupported connections that would cause + it to get stuck and fail running. + Or if a Component fails or returns output in an unsupported type. + :raises PipelineMaxComponentRuns: + If a Component reaches the maximum number of times it can be run in this Pipeline. + :raises PipelineBreakpointException: + When a pipeline_breakpoint is triggered. Contains the component name, state, and partial results. + """ + pipeline_running(self) + + if break_point and resume_state: + msg = ( + "pipeline_breakpoint and resume_state cannot be provided at the same time. " + "The pipeline run will be aborted." + ) + raise PipelineInvalidResumeStateError(message=msg) + + # make sure all breakpoints are valid, i.e. reference components in the pipeline + if break_point: + _validate_break_point(break_point, self.graph) + + # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not + # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() + self.warm_up() + + if include_outputs_from is None: + include_outputs_from = set() + + if not resume_state: + # normalize `data` + data = self._prepare_component_input_data(data) + + # Raise ValueError if input is malformed in some way + self.validate_input(data) + + # We create a list of components in the pipeline sorted by name, so that the algorithm runs + # deterministically and independent of insertion order into the pipeline. + ordered_component_names = sorted(self.graph.nodes.keys()) + + # We track component visits to decide if a component can run. + component_visits = dict.fromkeys(ordered_component_names, 0) + resume_agent_in_pipeline = False + + else: + # Handle resume state + component_visits, data, resume_agent_in_pipeline, ordered_component_names = self._handle_resume_state( + resume_state + ) + + cached_topological_sort = None + # We need to access a component's receivers multiple times during a pipeline run. + # We store them here for easy access. 
+ cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} + + pipeline_outputs: Dict[str, Any] = {} + with tracing.tracer.trace( + "haystack.pipeline.run", + tags={ + "haystack.pipeline.input_data": data, + "haystack.pipeline.output_data": pipeline_outputs, + "haystack.pipeline.metadata": self.metadata, + "haystack.pipeline.max_runs_per_component": self._max_runs_per_component, + }, + ) as span: + inputs = self._convert_to_internal_format(pipeline_inputs=data) + priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) + + # check if pipeline is blocked before execution + self.validate_pipeline(priority_queue) + + while True: + candidate = self._get_next_runnable_component(priority_queue, component_visits) + if candidate is None: + break + + priority, component_name, component = candidate + + if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: + component_name, topological_sort = self._tiebreak_waiting_components( + component_name=component_name, + priority=priority, + priority_queue=priority_queue, + topological_sort=cached_topological_sort, + ) + + cached_topological_sort = topological_sort + component = self._get_component_with_graph_metadata_and_visits( + component_name, component_visits[component_name] + ) + + is_resume = bool(resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name) + component_inputs = self._consume_component_inputs( + component_name=component_name, component=component, inputs=inputs, is_resume=is_resume + ) + + # We need to add missing defaults using default values from input sockets because the run signature + # might not provide these defaults for components with inputs defined dynamically upon component + # initialization + component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) + + # Scenario 1: Resume state is provided to resume the pipeline at a specific component + # Deserialize the component_inputs if they are passed in resume state + # this check will prevent other component_inputs generated at runtime from being deserialized + if resume_state and component_name in resume_state["pipeline_state"]["inputs"].keys(): + for key, value in component_inputs.items(): + component_inputs[key] = _deserialize_value_with_schema(value) + + # Scenario 2: a breakpoint is provided to stop the pipeline at a specific component and visit count + breakpoint_triggered = False + if break_point is not None: + agent_breakpoint = False + + if isinstance(break_point, AgentBreakpoint): + component_instance = component["instance"] + if isinstance(component_instance, Agent): + component_inputs = handle_agent_break_point( + break_point, + component_name, + component_inputs, + inputs, + component_visits, + ordered_component_names, + data, + debug_path, + ) + agent_breakpoint = True + + if not agent_breakpoint and isinstance(break_point, Breakpoint): + breakpoint_triggered = check_regular_break_point(break_point, component_name, component_visits) + + if breakpoint_triggered: + trigger_break_point( + component_name, + component_inputs, + inputs, + component_visits, + debug_path, + data, + ordered_component_names, + pipeline_outputs, + ) + + if resume_agent_in_pipeline: + # inject the resume_state into the component (the Agent) inputs + component_inputs["resume_state"] = resume_state + component_inputs["break_point"] = None + + component_outputs = self._run_component( + component_name=component_name, + 
component=component, + inputs=component_inputs, # the inputs to the current component + component_visits=component_visits, + parent_span=span, + ) + + # Updates global input state with component outputs and returns outputs that should go to + # pipeline outputs. + component_pipeline_outputs = self._write_component_outputs( + component_name=component_name, + component_outputs=component_outputs, + inputs=inputs, + receivers=cached_receivers[component_name], + include_outputs_from=include_outputs_from, + ) + + if component_pipeline_outputs: + pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) + if self._is_queue_stale(priority_queue): + priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) + + if break_point and not agent_breakpoint: + logger.warning( + "The given breakpoint {break_point} was never triggered. This is because:\n" + "1. The provided component is not a part of the pipeline execution path.\n" + "2. The component did not reach the visit count specified in the pipeline_breakpoint", + pipeline_breakpoint=break_point, + ) + + return pipeline_outputs + + def inject_resume_state_into_graph(self, resume_state): + """ + Loads the resume state from a file and injects it into the pipeline graph. + + """ + # We previously check if the resume_state is None but this is needed to prevent a typing error + if not resume_state: + raise PipelineInvalidResumeStateError("Cannot inject resume state: resume_state is None") + + # check if the resume_state is valid for the current pipeline + _validate_components_against_pipeline(resume_state, self.graph) + + data = self._prepare_component_input_data(resume_state["pipeline_state"]["inputs"]) + component_visits = resume_state["pipeline_state"]["component_visits"] + ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"] + logger.info( + "Resuming pipeline from {component} with visit count {visits}", + component=resume_state["pipeline_breakpoint"]["component"], + visits=resume_state["pipeline_breakpoint"]["visits"], + ) + + return component_visits, data, resume_state, ordered_component_names diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 3c3c347341..1b702e4073 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -2,11 +2,17 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Mapping, Optional, Set, cast +# pylint: disable=too-many-positional-arguments + +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, Mapping, Optional, Set, Union, cast from haystack import logging, tracing + +# from haystack.components.agents import Agent from haystack.core.component import Component -from haystack.core.errors import PipelineRuntimeError +from haystack.core.errors import PipelineInvalidResumeStateError, PipelineRuntimeError from haystack.core.pipeline.base import ( _COMPONENT_INPUT, _COMPONENT_OUTPUT, @@ -14,8 +20,17 @@ ComponentPriority, PipelineBase, ) +from haystack.core.pipeline.breakpoint import ( + _validate_break_point, + _validate_components_against_pipeline, + check_regular_break_point, + handle_agent_break_point, + trigger_break_point, +) from haystack.core.pipeline.utils import _deepcopy_with_exceptions +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint from haystack.telemetry import pipeline_running +from haystack.utils import _deserialize_value_with_schema logger = logging.getLogger(__name__) @@ -70,7 +85,8 @@ def 
_run_component( return cast(Dict[Any, Any], component_output) - def run( # noqa: PLR0915, PLR0912 + # ToDo: delete + def run_old( # noqa: PLR0915, PLR0912 self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None ) -> Dict[str, Any]: """ @@ -257,3 +273,355 @@ def run( # noqa: PLR0915, PLR0912 priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) return pipeline_outputs + + def _handle_resume_state(self, resume_state: Dict[str, Any]) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: + """ + Handle resume state initialization. + + :param resume_state: The resume state to handle + :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) + """ + if resume_state.get("agent_name"): + return self._handle_agent_resume_state(resume_state) + else: + return self._handle_regular_resume_state(resume_state) + + def _handle_agent_resume_state( + self, resume_state: Dict[str, Any] + ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: + """ + Handle agent-specific resume state. + + :param resume_state: The resume state to handle + :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) + """ + agent_name = resume_state["agent_name"] + for name, component in self.graph.nodes.items(): + if component["instance"].__class__.__name__ == "Agent" and name == agent_name: + main_pipeline_state = resume_state.get("main_pipeline_state", {}) + component_visits = main_pipeline_state.get("component_visits", {}) + ordered_component_names = main_pipeline_state.get("ordered_component_names", []) + data = _deserialize_value_with_schema(main_pipeline_state.get("inputs", {})) + return component_visits, data, True, ordered_component_names + + # Fallback to regular resume if agent not found + return self._handle_regular_resume_state(resume_state) + + def _handle_regular_resume_state( + self, resume_state: Dict[str, Any] + ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: + """ + Handle regular component resume state. + + :param resume_state: The resume state to handle + :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) + """ + component_visits, data, resume_state, ordered_component_names = self.inject_resume_state_into_graph( + resume_state=resume_state + ) + data = _deserialize_value_with_schema(resume_state["pipeline_state"]["inputs"]) + return component_visits, data, False, ordered_component_names + + def run( # noqa: PLR0915, PLR0912 + self, + data: Dict[str, Any], + include_outputs_from: Optional[Set[str]] = None, + break_point: Optional[Union[Breakpoint, AgentBreakpoint]] = None, + resume_state: Optional[Dict[str, Any]] = None, + debug_path: Optional[Union[str, Path]] = None, + ) -> Dict[str, Any]: + """ + Runs the Pipeline with given input data. 
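+
+        The run also supports breakpoints for debugging: pass `break_point` to stop right
+        before a component executes, and `resume_state` to continue a previously
+        interrupted run. A minimal sketch; the component name, query, and state file name
+        are placeholders:
+
+        ```python
+        from haystack.core.errors import BreakpointException
+        from haystack.core.pipeline.breakpoint import load_state
+        from haystack.dataclasses.breakpoints import Breakpoint
+
+        break_point = Breakpoint(component_name="retriever", visit_count=0)
+        try:
+            # stops before "retriever" runs and saves the pipeline state under ./debug
+            pipeline.run(
+                data={"retriever": {"query": "Who lives in Paris?"}},
+                break_point=break_point,
+                debug_path="debug",
+            )
+        except BreakpointException:
+            pass  # carries the component name and the saved state; depending on where it
+                  # triggers, it may also surface wrapped in a PipelineRuntimeError
+
+        # later: resume from the saved state file (the exact file name pattern may vary)
+        resume_state = load_state("debug/<saved_state_file>.json")
+        results = pipeline.run(data={}, resume_state=resume_state)
+
+        # to stop inside an Agent component instead, wrap the breakpoint:
+        # AgentBreakpoint(agent_name="my_agent", break_point=Breakpoint("chat_generator", 0))
+        ```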
+ + Usage: + ```python + from haystack import Pipeline, Document + from haystack.utils import Secret + from haystack.document_stores.in_memory import InMemoryDocumentStore + from haystack.components.retrievers.in_memory import InMemoryBM25Retriever + from haystack.components.generators import OpenAIGenerator + from haystack.components.builders.answer_builder import AnswerBuilder + from haystack.components.builders.prompt_builder import PromptBuilder + + # Write documents to InMemoryDocumentStore + document_store = InMemoryDocumentStore() + document_store.write_documents([ + Document(content="My name is Jean and I live in Paris."), + Document(content="My name is Mark and I live in Berlin."), + Document(content="My name is Giorgio and I live in Rome.") + ]) + + prompt_template = \"\"\" + Given these documents, answer the question. + Documents: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + Question: {{question}} + Answer: + \"\"\" + + retriever = InMemoryBM25Retriever(document_store=document_store) + prompt_builder = PromptBuilder(template=prompt_template) + llm = OpenAIGenerator(api_key=Secret.from_token(api_key)) + + rag_pipeline = Pipeline() + rag_pipeline.add_component("retriever", retriever) + rag_pipeline.add_component("prompt_builder", prompt_builder) + rag_pipeline.add_component("llm", llm) + rag_pipeline.connect("retriever", "prompt_builder.documents") + rag_pipeline.connect("prompt_builder", "llm") + + # Ask a question + question = "Who lives in Paris?" + results = rag_pipeline.run( + { + "retriever": {"query": question}, + "prompt_builder": {"question": question}, + } + ) + + print(results["llm"]["replies"]) + # Jean lives in Paris + ``` + + :param data: + A dictionary of inputs for the pipeline's components. Each key is a component name + and its value is a dictionary of that component's input parameters: + ``` + data = { + "comp1": {"input1": 1, "input2": 2}, + } + ``` + For convenience, this format is also supported when input names are unique: + ``` + data = { + "input1": 1, "input2": 2, + } + ``` + :param include_outputs_from: + Set of component names whose individual outputs are to be + included in the pipeline's output. For components that are + invoked multiple times (in a loop), only the last-produced + output is included. + + :param break_point: + A set of breakpoints that can be used to debug the pipeline execution. + + :param resume_state: + A dictionary containing the state of a previously saved pipeline execution. + + :param debug_path: + Path to the directory where the pipeline state should be saved. + + :returns: + A dictionary where each entry corresponds to a component name + and its output. If `include_outputs_from` is `None`, this dictionary + will only contain the outputs of leaf components, i.e., components + without outgoing connections. + + :raises ValueError: + If invalid inputs are provided to the pipeline. + :raises PipelineRuntimeError: + If the Pipeline contains cycles with unsupported connections that would cause + it to get stuck and fail running. + Or if a Component fails or returns output in an unsupported type. + :raises PipelineMaxComponentRuns: + If a Component reaches the maximum number of times it can be run in this Pipeline. + :raises PipelineBreakpointException: + When a pipeline_breakpoint is triggered. Contains the component name, state, and partial results. + """ + pipeline_running(self) + + if break_point and resume_state: + msg = ( + "pipeline_breakpoint and resume_state cannot be provided at the same time. 
" + "The pipeline run will be aborted." + ) + raise PipelineInvalidResumeStateError(message=msg) + + # make sure all breakpoints are valid, i.e. reference components in the pipeline + if break_point: + _validate_break_point(break_point, self.graph) + + # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not + # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() + self.warm_up() + + if include_outputs_from is None: + include_outputs_from = set() + + if not resume_state: + # normalize `data` + data = self._prepare_component_input_data(data) + + # Raise ValueError if input is malformed in some way + self.validate_input(data) + + # We create a list of components in the pipeline sorted by name, so that the algorithm runs + # deterministically and independent of insertion order into the pipeline. + ordered_component_names = sorted(self.graph.nodes.keys()) + + # We track component visits to decide if a component can run. + component_visits = dict.fromkeys(ordered_component_names, 0) + resume_agent_in_pipeline = False + + else: + # Handle resume state + component_visits, data, resume_agent_in_pipeline, ordered_component_names = self._handle_resume_state( + resume_state + ) + + cached_topological_sort = None + # We need to access a component's receivers multiple times during a pipeline run. + # We store them here for easy access. + cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} + + pipeline_outputs: Dict[str, Any] = {} + with tracing.tracer.trace( + "haystack.pipeline.run", + tags={ + "haystack.pipeline.input_data": data, + "haystack.pipeline.output_data": pipeline_outputs, + "haystack.pipeline.metadata": self.metadata, + "haystack.pipeline.max_runs_per_component": self._max_runs_per_component, + }, + ) as span: + inputs = self._convert_to_internal_format(pipeline_inputs=data) + priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) + + # check if pipeline is blocked before execution + self.validate_pipeline(priority_queue) + + while True: + candidate = self._get_next_runnable_component(priority_queue, component_visits) + if candidate is None: + break + + priority, component_name, component = candidate + + if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: + component_name, topological_sort = self._tiebreak_waiting_components( + component_name=component_name, + priority=priority, + priority_queue=priority_queue, + topological_sort=cached_topological_sort, + ) + + cached_topological_sort = topological_sort + component = self._get_component_with_graph_metadata_and_visits( + component_name, component_visits[component_name] + ) + + is_resume = bool(resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name) + component_inputs = self._consume_component_inputs( + component_name=component_name, component=component, inputs=inputs, is_resume=is_resume + ) + + # We need to add missing defaults using default values from input sockets because the run signature + # might not provide these defaults for components with inputs defined dynamically upon component + # initialization + component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) + + # Scenario 1: Resume state is provided to resume the pipeline at a specific component + # Deserialize the component_inputs if they are passed in resume state + # this check will prevent other 
component_inputs generated at runtime from being deserialized
+                if resume_state and component_name in resume_state["pipeline_state"]["inputs"].keys():
+                    for key, value in component_inputs.items():
+                        component_inputs[key] = _deserialize_value_with_schema(value)
+
+                # Scenario 2: a breakpoint is provided to stop the pipeline at a specific component and visit count
+                breakpoint_triggered = False
+                if break_point is not None:
+                    agent_breakpoint = False
+
+                    if isinstance(break_point, AgentBreakpoint):
+                        # component_instance = component["instance"]
+                        # if isinstance(component_instance, Agent):
+                        component_inputs = handle_agent_break_point(
+                            break_point,
+                            component_name,
+                            component_inputs,
+                            inputs,
+                            component_visits,
+                            ordered_component_names,
+                            data,
+                            debug_path,
+                        )
+                        agent_breakpoint = True
+
+                    if not agent_breakpoint and isinstance(break_point, Breakpoint):
+                        breakpoint_triggered = check_regular_break_point(break_point, component_name, component_visits)
+
+                    if breakpoint_triggered:
+                        trigger_break_point(
+                            component_name,
+                            component_inputs,
+                            inputs,
+                            component_visits,
+                            debug_path,
+                            data,
+                            ordered_component_names,
+                            pipeline_outputs,
+                        )
+
+                if resume_agent_in_pipeline:
+                    # inject the resume_state into the component (the Agent) inputs
+                    component_inputs["resume_state"] = resume_state
+                    component_inputs["break_point"] = None
+
+                component_outputs = self._run_component(
+                    component_name=component_name,
+                    component=component,
+                    inputs=component_inputs,  # the inputs to the current component
+                    component_visits=component_visits,
+                    parent_span=span,
+                )
+
+                # Updates the global input state with the component outputs and returns the outputs
+                # that belong in the pipeline outputs.
+                component_pipeline_outputs = self._write_component_outputs(
+                    component_name=component_name,
+                    component_outputs=component_outputs,
+                    inputs=inputs,
+                    receivers=cached_receivers[component_name],
+                    include_outputs_from=include_outputs_from,
+                )
+
+                if component_pipeline_outputs:
+                    pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs)
+                if self._is_queue_stale(priority_queue):
+                    priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits)
+
+            if break_point and not agent_breakpoint:
+                logger.warning(
+                    "The given breakpoint {break_point} was never triggered. Possible causes:\n"
+                    "1. The component is not part of the pipeline's execution path.\n"
+                    "2. The component never reached the visit count specified in the breakpoint.",
+                    pipeline_breakpoint=break_point,
+                )
+
+        return pipeline_outputs
+
+    def inject_resume_state_into_graph(self, resume_state):
+        """
+        Validates a resume state against the current pipeline graph and extracts the data needed to resume execution.
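+
+        The expected `resume_state` layout, as saved by the breakpoint machinery
+        (a sketch with placeholder values):
+
+        ```python
+        {
+            "pipeline_breakpoint": {"component": "retriever", "visits": 0},
+            "pipeline_state": {
+                "inputs": {...},  # serialized component inputs
+                "component_visits": {...},  # visit counters per component
+                "ordered_component_names": [...],  # deterministic execution order
+            },
+        }
+        ```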
+
+        """
+        # The caller already checks that resume_state is not None; this extra guard
+        # keeps the type checker happy.
+        if not resume_state:
+            raise PipelineInvalidResumeStateError("Cannot inject resume state: resume_state is None")
+
+        # check that the resume_state is valid for the current pipeline
+        _validate_components_against_pipeline(resume_state, self.graph)
+
+        data = self._prepare_component_input_data(resume_state["pipeline_state"]["inputs"])
+        component_visits = resume_state["pipeline_state"]["component_visits"]
+        ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"]
+        logger.info(
+            "Resuming pipeline from {component} with visit count {visits}",
+            component=resume_state["pipeline_breakpoint"]["component"],
+            visits=resume_state["pipeline_breakpoint"]["visits"],
+        )
+
+        return component_visits, data, resume_state, ordered_component_names
diff --git a/haystack/dataclasses/breakpoints.py b/haystack/dataclasses/breakpoints.py
new file mode 100644
index 0000000000..cf4443c29b
--- /dev/null
+++ b/haystack/dataclasses/breakpoints.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+
+@dataclass
+class Breakpoint:
+    """
+    A dataclass to hold a breakpoint for a component.
+    """
+
+    component_name: str
+    visit_count: int = 0
+
+    def __hash__(self):
+        return hash((self.component_name, self.visit_count))
+
+    def __eq__(self, other):
+        if not isinstance(other, Breakpoint):
+            return False
+        return self.component_name == other.component_name and self.visit_count == other.visit_count
+
+    def __str__(self):
+        return f"Breakpoint(component_name={self.component_name}, visit_count={self.visit_count})"
+
+    def __repr__(self):
+        return self.__str__()
+
+
+@dataclass
+class ToolBreakpoint(Breakpoint):
+    """
+    A dataclass to hold a breakpoint that can be used to debug a Tool.
+
+    If tool_name is None, the breakpoint applies to every tool in the component.
+    Otherwise, it applies only to the tool with the given name.
+    """
+
+    tool_name: Optional[str] = None
+
+    def __hash__(self):
+        return hash((self.component_name, self.visit_count, self.tool_name))
+
+    def __eq__(self, other):
+        if not isinstance(other, ToolBreakpoint):
+            return False
+        return super().__eq__(other) and self.tool_name == other.tool_name
+
+    def __str__(self):
+        if self.tool_name:
+            return (
+                f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}, "
+                f"tool_name={self.tool_name})"
+            )
+        else:
+            return (
+                f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}, "
+                f"tool_name=ALL_TOOLS)"
+            )
+
+    def __repr__(self):
+        return self.__str__()
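+
+
+# Illustrative usage sketch (names are placeholders): breakpoints for an Agent wrap either a
+# Breakpoint on the internal "chat_generator" or a ToolBreakpoint on the "tool_invoker",
+# using the AgentBreakpoint class defined below.
+#
+#   llm_bp = AgentBreakpoint(agent_name="my_agent", break_point=Breakpoint("chat_generator", 0))
+#   tool_bp = AgentBreakpoint(
+#       agent_name="my_agent",
+#       break_point=ToolBreakpoint("tool_invoker", 0, tool_name="my_tool"),
+#   )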
+
+
+@dataclass
+class AgentBreakpoint:
+    """
+    A dataclass to hold a breakpoint that can be used to debug an Agent.
+    """
+
+    break_point: Union[Breakpoint, ToolBreakpoint]
+    agent_name: str = ""
+
+    def __init__(self, agent_name: str, break_point: Union[Breakpoint, ToolBreakpoint]):
+        if not break_point:
+            raise ValueError("A Breakpoint must be provided.")
+
+        # ToolBreakpoint subclasses Breakpoint, so this check accepts both
+        if not isinstance(break_point, Breakpoint):
+            raise ValueError("The breakpoint must be either Breakpoint or ToolBreakpoint.")
+
+        if not isinstance(break_point, ToolBreakpoint) and break_point.component_name != "chat_generator":
+            raise ValueError(
+                "The break_point must be a Breakpoint that has the component_name "
+                "'chat_generator' or be a ToolBreakpoint."
+            )
+
+        self.agent_name = agent_name
+        self.break_point = break_point
diff --git a/pyproject.toml b/pyproject.toml
index fdf10524dc..3d44da0f93 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -144,6 +144,7 @@ dependencies = [
   "pip",  # mypy needs pip to install missing stub packages
   "pylint",
   "ipython",
+  "colorama==0.4.6",  # Pipeline checkpoints test - TODO: rewrite the test without this lib
 ]
 
 [tool.hatch.envs.test.scripts]
diff --git a/test/components/agents/test_agent_breakpoints_inside_pipeline.py b/test/components/agents/test_agent_breakpoints_inside_pipeline.py
new file mode 100644
index 0000000000..6a56e6ec17
--- /dev/null
+++ b/test/components/agents/test_agent_breakpoints_inside_pipeline.py
@@ -0,0 +1,410 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import tempfile
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from haystack import component
+from haystack.components.agents import Agent
+from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.core.errors import BreakpointException, PipelineRuntimeError
+from haystack.core.pipeline import Pipeline
+from haystack.core.pipeline.breakpoint import load_state
+from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
+from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.tools import tool
+
+document_store = InMemoryDocumentStore()
+
+
+@component
+class MockLinkContentFetcher:
+    @component.output_types(streams=List[ByteStream])
+    def run(self, urls: List[str]) -> Dict[str, List[ByteStream]]:
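+        # Return a fixed HTML page describing Deepset's leadership; the downstream
+        # MockHTMLToDocument strips the tags, so only the text content matters.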
+        mock_html_content = """
+        <html>
+          <head>
+            <title>Deepset - About Our Team</title>
+          </head>
+          <body>
+            <h1>About Deepset</h1>
+
+            <p>Deepset is a company focused on natural language processing and AI.</p>
+
+            <h2>Our Leadership Team</h2>
+
+            <div class="team-member">
+              <h3>Malte Pietsch</h3>
+              <p>Malte Pietsch is the CEO and co-founder of Deepset. He has extensive experience in machine learning
+              and natural language processing.</p>
+              <p>Job Title: Chief Executive Officer</p>
+            </div>
+
+            <div class="team-member">
+              <h3>Milos Rusic</h3>
+              <p>Milos Rusic is the CTO and co-founder of Deepset. He specializes in building scalable AI systems
+              and has worked on various NLP projects.</p>
+              <p>Job Title: Chief Technology Officer</p>
+            </div>
+
+            <h2>Our Mission</h2>
+
+            <p>Deepset aims to make natural language processing accessible to developers and businesses worldwide
+            through open-source tools and enterprise solutions.</p>
+          </body>
+        </html>
+        """
+
+        bytestream = ByteStream(
+            data=mock_html_content.encode("utf-8"),
+            mime_type="text/html",
+            meta={"url": urls[0] if urls else "https://en.wikipedia.org/wiki/Deepset"},
+        )
+
+        return {"streams": [bytestream]}
+
+
+@component
+class MockHTMLToDocument:
+    @component.output_types(documents=List[Document])
+    def run(self, sources: List[ByteStream]) -> Dict[str, List[Document]]:
+        """Mock HTML to Document converter that extracts text content from HTML ByteStreams."""
+
+        documents = []
+        for source in sources:
+            # Extract the HTML content from the ByteStream
+            html_content = source.data.decode("utf-8")
+
+            # Simple text extraction: remove HTML tags and keep the main content
+            import re
+
+            # Remove HTML tags
+            text_content = re.sub(r"<[^>]+>", " ", html_content)
+            # Remove extra whitespace
+            text_content = re.sub(r"\s+", " ", text_content).strip()
+
+            # Create a Document with the extracted text
+            document = Document(
+                content=text_content,
+                meta={"url": source.meta.get("url", "unknown"), "mime_type": source.mime_type, "source_type": "html"},
+            )
+            documents.append(document)
+
+        return {"documents": documents}
+
+
+@tool
+def add_database_tool(name: str, surname: str, job_title: Optional[str], other: Optional[str]):
+    document_store.write_documents(
+        [Document(content=name + " " + surname + " " + (job_title or ""), meta={"other": other})]
+    )
+
+
+def create_pipeline():
+    generator = OpenAIChatGenerator()
+    call_count = 0
+
+    def mock_run(messages, tools=None, **kwargs):
+        nonlocal call_count
+        call_count += 1
+
+        if call_count == 1:
+            return {
+                "replies": [
+                    ChatMessage.from_assistant(
+                        "I'll extract the information about the people mentioned in the context.",
+                        tool_calls=[
+                            ToolCall(
+                                tool_name="add_database_tool",
+                                arguments={
+                                    "name": "Malte",
+                                    "surname": "Pietsch",
+                                    "job_title": "Chief Executive Officer",
+                                    "other": "CEO and co-founder of Deepset with extensive experience in machine "
+                                    "learning and natural language processing",
+                                },
+                            ),
+                            ToolCall(
+                                tool_name="add_database_tool",
+                                arguments={
+                                    "name": "Milos",
+                                    "surname": "Rusic",
+                                    "job_title": "Chief Technology Officer",
+                                    "other": "CTO and co-founder of Deepset specializing in building scalable "
+                                    "AI systems and NLP projects",
+                                },
+                            ),
+                        ],
+                    )
+                ]
+            }
+        else:
+            return {
+                "replies": [
+                    ChatMessage.from_assistant(
+                        "I have successfully extracted and stored information about the following people:\n\n"
+                        "1. **Malte Pietsch** - Chief Executive Officer\n"
+                        "   - CEO and co-founder of Deepset\n"
+                        "   - Extensive experience in machine learning and natural language processing\n\n"
+                        "2. **Milos Rusic** - Chief Technology Officer\n"
+                        "   - CTO and co-founder of Deepset\n"
+                        "   - Specializes in building scalable AI systems and NLP projects\n\n"
+                        "Both individuals have been added to the knowledge base with their respective information."
+                    )
+                ]
+            }
+
+    generator.run = mock_run
+
+    database_assistant = Agent(
+        chat_generator=generator,
+        tools=[add_database_tool],
+        system_prompt="""
+        You are a database assistant.
+        Your task is to extract the names of people mentioned in the given context and add them to a knowledge base,
+        along with additional relevant information about them that can be extracted from the context.
+        Do not use your own knowledge, stay grounded in the given context.
+        Do not ask the user for confirmation.
Instead, automatically update the knowledge base and return a brief + summary of the people added, including the information stored for each. + """, + exit_conditions=["text"], + max_agent_steps=100, + raise_on_tool_invocation_failure=False, + ) + + extraction_agent = Pipeline() + extraction_agent.add_component("fetcher", MockLinkContentFetcher()) + extraction_agent.add_component("converter", MockHTMLToDocument()) + extraction_agent.add_component( + "builder", + ChatPromptBuilder( + template=[ + ChatMessage.from_user(""" + {% for doc in docs %} + {{ doc.content|default|truncate(25000) }} + {% endfor %} + """) + ], + required_variables=["docs"], + ), + ) + extraction_agent.add_component("database_agent", database_assistant) + + extraction_agent.connect("fetcher.streams", "converter.sources") + extraction_agent.connect("converter.documents", "builder.docs") + extraction_agent.connect("builder", "database_agent") + + return extraction_agent + + +def run_pipeline_without_any_breakpoints(): + pipeline_with_agent = create_pipeline() + agent_output = pipeline_with_agent.run(data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}) + + # pipeline completed + assert "database_agent" in agent_output + assert "messages" in agent_output["database_agent"] + assert len(agent_output["database_agent"]["messages"]) > 0 + + # final message contains the expected summary + final_message = agent_output["database_agent"]["messages"][-1].text + assert "Malte Pietsch" in final_message + assert "Milos Rusic" in final_message + assert "Chief Executive Officer" in final_message + assert "Chief Technology Officer" in final_message + + +def test_chat_generator_breakpoint_in_pipeline_agent(): + pipeline_with_agent = create_pipeline() + agent_generator_breakpoint = Breakpoint("chat_generator", 0) + agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") + + with tempfile.TemporaryDirectory() as debug_path: + try: + pipeline_with_agent.run( + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, + break_point=agent_breakpoint, + debug_path=debug_path, + ) + assert False, "Expected exception was not raised" + + except BreakpointException as e: # this is the exception from the Agent + assert e.component == "chat_generator" + assert e.state is not None + assert "messages" in e.state + assert e.results is not None + except PipelineRuntimeError as e: + # propagated exception to core Pipeline - assure that the cause is a PipelineBreakpointException + if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException): + original_exception = e.__cause__ + assert original_exception.component == "chat_generator" + assert original_exception.state is not None + assert "messages" in original_exception.state + assert original_exception.results is not None + else: + # re-raise if it's a different PipelineRuntimeError - test failed + raise + + # verify that debug/state file was created + chat_generator_state_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) + assert len(chat_generator_state_files) > 0, f"No chat_generator state files found in {debug_path}" + + +def test_tool_breakpoint_in_pipeline_agent(): + pipeline_with_agent = create_pipeline() + agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, "add_database_tool") + agent_breakpoints = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") + + with tempfile.TemporaryDirectory() as debug_path: + try: + pipeline_with_agent.run( + 
data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, + break_point=agent_breakpoints, + debug_path=debug_path, + ) + assert False, "Expected exception was not raised" + except BreakpointException as e: # this is the exception from the Agent + assert e.component == "tool_invoker" + assert e.state is not None + assert "messages" in e.state + assert e.results is not None + except PipelineRuntimeError as e: + # propagated exception to core Pipeline - assure that the cause is a PipelineBreakpointException + if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException): + original_exception = e.__cause__ + assert original_exception.component == "tool_invoker" + assert original_exception.state is not None + assert "messages" in original_exception.state + assert original_exception.results is not None + else: + # re-raise if it's a different PipelineRuntimeError - test failed + raise + + # verify that debug/state file was created + tool_invoker_state_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) + assert len(tool_invoker_state_files) > 0, f"No tool_invoker state files found in {debug_path}" + + +def test_agent_breakpoint_chat_generator_and_resume_pipeline(): + pipeline_with_agent = create_pipeline() + agent_generator_breakpoint = Breakpoint("chat_generator", 0) + agent_breakpoints = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") + + with tempfile.TemporaryDirectory() as debug_path: + try: + pipeline_with_agent.run( + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, + break_point=agent_breakpoints, + debug_path=debug_path, + ) + assert False, "Expected PipelineBreakpointException was not raised" + + except BreakpointException as e: + assert e.component == "chat_generator" + assert e.state is not None + assert "messages" in e.state + assert e.results is not None + + except PipelineRuntimeError as e: + if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException): + original_exception = e.__cause__ + assert original_exception.component == "chat_generator" + assert original_exception.state is not None + assert "messages" in original_exception.state + assert original_exception.results is not None + else: + raise + + # verify that the state file was created + chat_generator_state_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) + assert len(chat_generator_state_files) > 0, f"No chat_generator state files found in {debug_path}" + + # resume the pipeline from the saved state + latest_state_file = max(chat_generator_state_files, key=os.path.getctime) + resume_state = load_state(latest_state_file) + + result = pipeline_with_agent.run(data={}, resume_state=resume_state) + + # pipeline completed successfully after resuming + assert "database_agent" in result + assert "messages" in result["database_agent"] + assert len(result["database_agent"]["messages"]) > 0 + + # final message contains the expected summary + final_message = result["database_agent"]["messages"][-1].text + assert "Malte Pietsch" in final_message + assert "Milos Rusic" in final_message + assert "Chief Executive Officer" in final_message + assert "Chief Technology Officer" in final_message + + # tool should have been called during the resumed execution + documents = document_store.filter_documents() + assert len(documents) >= 2, "Expected at least 2 documents to be added to the database" + + # both people were added + person_names = [doc.content for doc in documents] + assert any("Malte 
Pietsch" in name for name in person_names) + assert any("Milos Rusic" in name for name in person_names) + + +def test_agent_breakpoint_tool_and_resume_pipeline(): + pipeline_with_agent = create_pipeline() + agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, "add_database_tool") + agent_breakpoints = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") + + with tempfile.TemporaryDirectory() as debug_path: + try: + pipeline_with_agent.run( + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, + break_point=agent_breakpoints, + debug_path=debug_path, + ) + assert False, "Expected PipelineBreakpointException was not raised" + + except BreakpointException as e: + assert e.component == "tool_invoker" + assert e.state is not None + assert "messages" in e.state + assert e.results is not None + + except PipelineRuntimeError as e: + if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException): + original_exception = e.__cause__ + assert original_exception.component == "tool_invoker" + assert original_exception.state is not None + assert "messages" in original_exception.state + assert original_exception.results is not None + else: + raise + + # verify that the state file was created + tool_invoker_state_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) + assert len(tool_invoker_state_files) > 0, f"No tool_invoker state files found in {debug_path}" + + # resume the pipeline from the saved state + latest_state_file = max(tool_invoker_state_files, key=os.path.getctime) + resume_state = load_state(latest_state_file) + + result = pipeline_with_agent.run(data={}, resume_state=resume_state) + + # pipeline completed successfully after resuming + assert "database_agent" in result + assert "messages" in result["database_agent"] + assert len(result["database_agent"]["messages"]) > 0 + + # final message contains the expected summary + final_message = result["database_agent"]["messages"][-1].text + assert "Malte Pietsch" in final_message + assert "Milos Rusic" in final_message + assert "Chief Executive Officer" in final_message + assert "Chief Technology Officer" in final_message + + # tool should have been called during the resumed execution + documents = document_store.filter_documents() + assert len(documents) >= 2, "Expected at least 2 documents to be added to the database" + + # both people were added + person_names = [doc.content for doc in documents] + assert any("Malte Pietsch" in name for name in person_names) + assert any("Milos Rusic" in name for name in person_names) diff --git a/test/components/agents/test_agent_breakpoints_isolation_async.py b/test/components/agents/test_agent_breakpoints_isolation_async.py new file mode 100644 index 0000000000..4f86875605 --- /dev/null +++ b/test/components/agents/test_agent_breakpoints_isolation_async.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path +from typing import Optional +from unittest.mock import AsyncMock + +import pytest + +from haystack.components.agents import Agent +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_state +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint +from haystack.tools import Tool +from test.components.agents.test_agent import MockChatGeneratorWithRunAsync, weather_function + +agent_name 
= "isolated_agent" + + +def create_chat_generator_breakpoint(visit_count: int = 0) -> Breakpoint: + return Breakpoint(component_name="chat_generator", visit_count=visit_count) + + +def create_tool_breakpoint(tool_name: Optional[str] = None, visit_count: int = 0) -> ToolBreakpoint: + return ToolBreakpoint(component_name="tool_invoker", visit_count=visit_count, tool_name=tool_name) + + +@pytest.fixture +def weather_tool(): + return Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + function=weather_function, + ) + + +@pytest.fixture +def mock_chat_generator(): + generator = MockChatGeneratorWithRunAsync() + mock_run_async = AsyncMock() + mock_run_async.return_value = { + "replies": [ + ChatMessage.from_assistant( + "I'll help you check the weather.", + tool_calls=[{"tool_name": "weather_tool", "tool_args": {"location": "Berlin"}}], + ) + ] + } + + async def mock_run_async_with_tools(messages, tools=None, **kwargs): + return mock_run_async.return_value + + generator.run_async = mock_run_async_with_tools + return generator + + +@pytest.fixture +def agent(mock_chat_generator, weather_tool): + return Agent( + chat_generator=mock_chat_generator, + tools=[weather_tool], + system_prompt="You are a helpful assistant that can use tools to help users.", + max_agent_steps=10, # Increase max steps to allow breakpoints to trigger + ) + + +@pytest.fixture +def debug_path(tmp_path): + return str(tmp_path / "debug_states") + + +@pytest.fixture +def mock_agent_with_tool_calls(monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = MockChatGeneratorWithRunAsync() + mock_messages = [ + ChatMessage.from_assistant("First response"), + ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]), + ] + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=10) # Increase max steps + agent.warm_up() + agent.chat_generator.run_async = AsyncMock(return_value={"replies": mock_messages}) + return agent + + +@pytest.mark.asyncio +async def test_run_async_with_chat_generator_breakpoint(agent): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test") + with pytest.raises(BreakpointException) as exc_info: + await agent.run_async(messages=messages, break_point=agent_breakpoint, agent_name=agent_name) + assert exc_info.value.component == "chat_generator" + assert "messages" in exc_info.value.state + + +@pytest.mark.asyncio +async def test_run_async_with_tool_invoker_breakpoint(mock_agent_with_tool_calls): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") + with pytest.raises(BreakpointException) as exc_info: + await mock_agent_with_tool_calls.run_async( + messages=messages, break_point=agent_breakpoint, agent_name=agent_name + ) + + assert exc_info.value.component == "tool_invoker" + assert "messages" in exc_info.value.state + + +@pytest.mark.asyncio +async def test_resume_from_chat_generator_async(agent, debug_path): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = 
create_chat_generator_breakpoint(visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=agent_name) + + try: + await agent.run_async( + messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name + ) + except BreakpointException: + pass + + state_files = list(Path(debug_path).glob(agent_name + "_chat_generator_*.json")) + + assert len(state_files) > 0 + latest_state_file = str(max(state_files, key=os.path.getctime)) + + resume_state = load_state(latest_state_file) + result = await agent.run_async( + messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +@pytest.mark.asyncio +async def test_resume_from_tool_invoker_async(mock_agent_with_tool_calls, debug_path): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=agent_name) + + try: + await mock_agent_with_tool_calls.run_async( + messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name + ) + except BreakpointException: + pass + + state_files = list(Path(debug_path).glob(agent_name + "_tool_invoker_*.json")) + + assert len(state_files) > 0 + latest_state_file = str(max(state_files, key=os.path.getctime)) + + resume_state = load_state(latest_state_file) + + result = await mock_agent_with_tool_calls.run_async( + messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +@pytest.mark.asyncio +async def test_invalid_combination_breakpoint_and_resume_state_async(mock_agent_with_tool_calls): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") + with pytest.raises(ValueError, match="agent_breakpoint and resume_state cannot be provided at the same time"): + await mock_agent_with_tool_calls.run_async( + messages=messages, break_point=agent_breakpoint, resume_state={"some": "state"} + ) + + +@pytest.mark.asyncio +async def test_breakpoint_with_invalid_component_async(mock_agent_with_tool_calls): + invalid_bp = Breakpoint(component_name="invalid_breakpoint", visit_count=0) + with pytest.raises(ValueError): + AgentBreakpoint(break_point=invalid_bp, agent_name="test") + + +@pytest.mark.asyncio +async def test_breakpoint_with_invalid_tool_name_async(mock_agent_with_tool_calls): + tool_breakpoint = create_tool_breakpoint(tool_name="invalid_tool", visit_count=0) + with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"): + agent_breakpoint = AgentBreakpoint(break_point=tool_breakpoint, agent_name="test") + await mock_agent_with_tool_calls.run_async( + messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint + ) diff --git a/test/components/agents/test_agent_breakpoints_isolation_sync.py b/test/components/agents/test_agent_breakpoints_isolation_sync.py new file mode 100644 index 0000000000..23c3b5b277 --- /dev/null +++ b/test/components/agents/test_agent_breakpoints_isolation_sync.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: 2022-present deepset 
GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path + +import pytest + +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_state +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint +from test.components.agents.test_agent_breakpoints_utils import ( + agent_sync, + create_chat_generator_breakpoint, + create_tool_breakpoint, + mock_agent_with_tool_calls_sync, + weather_tool, +) + +agent_name = "isolated_agent" + + +def test_run_with_chat_generator_breakpoint(agent_sync): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test_agent") + with pytest.raises(BreakpointException) as exc_info: + agent_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") + assert exc_info.value.component == "chat_generator" + assert "messages" in exc_info.value.state + + +def test_run_with_tool_invoker_breakpoint(mock_agent_with_tool_calls_sync): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") + with pytest.raises(BreakpointException) as exc_info: + mock_agent_with_tool_calls_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") + + assert exc_info.value.component == "tool_invoker" + assert "messages" in exc_info.value.state + + +def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=agent_name) + debug_path = str(tmp_path / "debug_states") + + try: + agent_sync.run(messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name) + except BreakpointException: + pass + + state_files = list(Path(debug_path).glob(agent_name + "_chat_generator_*.json")) + assert len(state_files) > 0 + latest_state_file = str(max(state_files, key=os.path.getctime)) + + resume_state = load_state(latest_state_file) + result = agent_sync.run( + messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +def test_resume_from_tool_invoker(mock_agent_with_tool_calls_sync, tmp_path): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=agent_name) + debug_path = str(tmp_path / "debug_states") + + try: + mock_agent_with_tool_calls_sync.run( + messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name + ) + except BreakpointException: + pass + + state_files = list(Path(debug_path).glob(agent_name + "_tool_invoker_*.json")) + assert len(state_files) > 0 + latest_state_file = str(max(state_files, key=os.path.getctime)) + + resume_state = load_state(latest_state_file) + + result = mock_agent_with_tool_calls_sync.run( + messages=[ChatMessage.from_user("Continue from 
where we left off.")], resume_state=resume_state + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +def test_invalid_combination_breakpoint_and_resume_state(mock_agent_with_tool_calls_sync): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") + with pytest.raises(ValueError, match="agent_breakpoint and resume_state cannot be provided at the same time"): + mock_agent_with_tool_calls_sync.run( + messages=messages, break_point=agent_breakpoint, resume_state={"some": "state"} + ) + + +def test_breakpoint_with_invalid_component(mock_agent_with_tool_calls_sync): # noqa: F811 + invalid_bp = Breakpoint(component_name="invalid_breakpoint", visit_count=0) + with pytest.raises(ValueError): + AgentBreakpoint(break_point=invalid_bp, agent_name="test_agent") + + +def test_breakpoint_with_invalid_tool_name(mock_agent_with_tool_calls_sync): # noqa: F811 + tool_breakpoint = create_tool_breakpoint(tool_name="invalid_tool", visit_count=0) + with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"): + agent_breakpoints = AgentBreakpoint(break_point=tool_breakpoint, agent_name="test_agent") + mock_agent_with_tool_calls_sync.run( + messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoints + ) diff --git a/test/components/agents/test_agent_breakpoints_utils.py b/test/components/agents/test_agent_breakpoints_utils.py new file mode 100644 index 0000000000..a7beeaecf8 --- /dev/null +++ b/test/components/agents/test_agent_breakpoints_utils.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from haystack.components.agents import Agent +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.dataclasses.breakpoints import Breakpoint, ToolBreakpoint +from haystack.tools import Tool +from test.components.agents.test_agent import ( + MockChatGeneratorWithoutRunAsync, + MockChatGeneratorWithRunAsync, + weather_function, +) + + +def create_chat_generator_breakpoint(visit_count: int = 0) -> Breakpoint: + return Breakpoint(component_name="chat_generator", visit_count=visit_count) + + +def create_tool_breakpoint(tool_name: Optional[str] = None, visit_count: int = 0) -> ToolBreakpoint: + return ToolBreakpoint(component_name="tool_invoker", visit_count=visit_count, tool_name=tool_name) + + +# Common fixtures +@pytest.fixture +def weather_tool(): + return Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + function=weather_function, + ) + + +@pytest.fixture +def debug_path(tmp_path): + return str(tmp_path / "debug_states") + + +@pytest.fixture +def agent_sync(weather_tool): + generator = MockChatGeneratorWithoutRunAsync() + mock_run = MagicMock() + mock_run.return_value = { + "replies": [ + ChatMessage.from_assistant( + "I'll help you check the weather.", + tool_calls=[{"tool_name": "weather_tool", "tool_args": {"location": "Berlin"}}], + ) + ] + } + + def mock_run_with_tools(messages, tools=None, **kwargs): + 
return mock_run.return_value + + generator.run = mock_run_with_tools + + return Agent( + chat_generator=generator, + tools=[weather_tool], + system_prompt="You are a helpful assistant that can use tools to help users.", + ) + + +@pytest.fixture +def mock_agent_with_tool_calls_sync(monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = OpenAIChatGenerator() + mock_messages = [ + ChatMessage.from_assistant("First response"), + ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]), + ] + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1) + agent.warm_up() + agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages}) + return agent + + +@pytest.fixture +def agent_async(weather_tool): + generator = MockChatGeneratorWithRunAsync() + mock_run_async = AsyncMock() + mock_run_async.return_value = { + "replies": [ + ChatMessage.from_assistant( + "I'll help you check the weather.", + tool_calls=[{"tool_name": "weather_tool", "tool_args": {"location": "Berlin"}}], + ) + ] + } + + async def mock_run_async_with_tools(messages, tools=None, **kwargs): + return mock_run_async.return_value + + generator.run_async = mock_run_async_with_tools + return Agent( + chat_generator=generator, + tools=[weather_tool], + system_prompt="You are a helpful assistant that can use tools to help users.", + ) + + +@pytest.fixture +def mock_agent_with_tool_calls_async(monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = MockChatGeneratorWithRunAsync() + mock_messages = [ + ChatMessage.from_assistant("First response"), + ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]), + ] + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1) + agent.warm_up() + agent.chat_generator.run_async = AsyncMock(return_value={"replies": mock_messages}) + return agent diff --git a/test/components/agents/test_state_class_experimental.py b/test/components/agents/test_state_class_experimental.py new file mode 100644 index 0000000000..9b7d3d7728 --- /dev/null +++ b/test/components/agents/test_state_class_experimental.py @@ -0,0 +1,477 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from dataclasses import dataclass +from typing import Dict, Generic, List, Optional, TypeVar, Union + +import pytest + +from haystack.components.agents.state.state import ( + State, + _is_list_type, + _is_valid_type, + _schema_from_dict, + _schema_to_dict, + _validate_schema, + merge_lists, +) +from haystack.dataclasses import ChatMessage + + +@pytest.fixture +def basic_schema(): + return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}} + + +def numbers_handler(current, new): + if current is None: + return sorted(set(new)) + return sorted(set(current + new)) + + +@pytest.fixture +def complex_schema(): + return {"numbers": {"type": list, "handler": numbers_handler}, "metadata": {"type": dict}, "name": {"type": str}} + + +def test_is_list_type(): + assert _is_list_type(list) is True + assert _is_list_type(List[int]) is True + assert _is_list_type(List[str]) is True + assert _is_list_type(dict) is False + assert _is_list_type(int) is False + assert _is_list_type(Union[List[int], None]) is False + + +class TestMergeLists: + def test_merge_two_lists(self): + current = [1, 2, 3] + new = [4, 5, 6] + result = 
merge_lists(current, new) + assert result == [1, 2, 3, 4, 5, 6] + # Ensure original lists weren't modified + assert current == [1, 2, 3] + assert new == [4, 5, 6] + + def test_append_to_list(self): + current = [1, 2, 3] + new = 4 + result = merge_lists(current, new) + assert result == [1, 2, 3, 4] + assert current == [1, 2, 3] # Ensure original wasn't modified + + def test_create_new_list(self): + current = 1 + new = 2 + result = merge_lists(current, new) + assert result == [1, 2] + + def test_replace_with_list(self): + current = 1 + new = [2, 3] + result = merge_lists(current, new) + assert result == [1, 2, 3] + + +class TestIsValidType: + def test_builtin_types(self): + assert _is_valid_type(str) is True + assert _is_valid_type(int) is True + assert _is_valid_type(dict) is True + assert _is_valid_type(list) is True + assert _is_valid_type(tuple) is True + assert _is_valid_type(set) is True + assert _is_valid_type(bool) is True + assert _is_valid_type(float) is True + + def test_generic_types(self): + assert _is_valid_type(List[str]) is True + assert _is_valid_type(Dict[str, int]) is True + assert _is_valid_type(List[Dict[str, int]]) is True + assert _is_valid_type(Dict[str, List[int]]) is True + + def test_custom_classes(self): + @dataclass + class CustomClass: + value: int + + T = TypeVar("T") + + class GenericCustomClass(Generic[T]): + pass + + # Test regular and generic custom classes + assert _is_valid_type(CustomClass) is True + assert _is_valid_type(GenericCustomClass) is True + assert _is_valid_type(GenericCustomClass[int]) is True + + # Test generic types with custom classes + assert _is_valid_type(List[CustomClass]) is True + assert _is_valid_type(Dict[str, CustomClass]) is True + assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True + + def test_invalid_types(self): + # Test regular values + assert _is_valid_type(42) is False + assert _is_valid_type("string") is False + assert _is_valid_type([1, 2, 3]) is False + assert _is_valid_type({"a": 1}) is False + assert _is_valid_type(True) is False + + # Test class instances + @dataclass + class SampleClass: + value: int + + instance = SampleClass(42) + assert _is_valid_type(instance) is False + + # Test callable objects + assert _is_valid_type(len) is False + assert _is_valid_type(lambda x: x) is False + assert _is_valid_type(print) is False + + def test_union_and_optional_types(self): + # Test basic Union types + assert _is_valid_type(Union[str, int]) is True + assert _is_valid_type(Union[str, None]) is True + assert _is_valid_type(Union[List[int], Dict[str, str]]) is True + + # Test Optional types (which are Union[T, None]) + assert _is_valid_type(Optional[str]) is True + assert _is_valid_type(Optional[List[int]]) is True + assert _is_valid_type(Optional[Dict[str, list]]) is True + + # Test that Union itself is not a valid type (only instantiated Unions are) + assert _is_valid_type(Union) is False + + def test_nested_generic_types(self): + assert _is_valid_type(List[List[Dict[str, List[int]]]]) is True + assert _is_valid_type(Dict[str, List[Dict[str, set]]]) is True + assert _is_valid_type(Dict[str, Optional[List[int]]]) is True + assert _is_valid_type(List[Union[str, Dict[str, List[int]]]]) is True + + def test_edge_cases(self): + # Test None and NoneType + assert _is_valid_type(None) is False + assert _is_valid_type(type(None)) is True + + # Test functions and methods + def sample_func(): + pass + + assert _is_valid_type(sample_func) is False + assert _is_valid_type(type(sample_func)) is True + + # Test modules + 
assert _is_valid_type(inspect) is False + + # Test type itself + assert _is_valid_type(type) is True + + @pytest.mark.parametrize( + "test_input,expected", + [ + (str, True), + (int, True), + (List[int], True), + (Dict[str, int], True), + (Union[str, int], True), + (Optional[str], True), + (42, False), + ("string", False), + ([1, 2, 3], False), + (lambda x: x, False), + ], + ) + def test_parametrized_cases(self, test_input, expected): + assert _is_valid_type(test_input) is expected + + +class TestState: + def test_validate_schema_valid(self, basic_schema): + # Should not raise any exceptions + _validate_schema(basic_schema) + + def test_validate_schema_invalid_type(self): + invalid_schema = {"test": {"type": "not_a_type"}} + with pytest.raises(ValueError, match="must be a Python type"): + _validate_schema(invalid_schema) + + def test_validate_schema_missing_type(self): + invalid_schema = {"test": {"handler": lambda x, y: x + y}} + with pytest.raises(ValueError, match="missing a 'type' entry"): + _validate_schema(invalid_schema) + + def test_validate_schema_invalid_handler(self): + invalid_schema = {"test": {"type": list, "handler": "not_callable"}} + with pytest.raises(ValueError, match="must be callable or None"): + _validate_schema(invalid_schema) + + def test_state_initialization(self, basic_schema): + # Test empty initialization + state = State(basic_schema) + assert state.data == {} + + # Test initialization with data + initial_data = {"numbers": [1, 2, 3], "name": "test"} + state = State(basic_schema, initial_data) + assert state.data["numbers"] == [1, 2, 3] + assert state.data["name"] == "test" + + def test_state_get(self, basic_schema): + state = State(basic_schema, {"name": "test"}) + assert state.get("name") == "test" + assert state.get("non_existent") is None + assert state.get("non_existent", "default") == "default" + + def test_state_set_basic(self, basic_schema): + state = State(basic_schema) + + # Test setting new values + state.set("numbers", [1, 2]) + assert state.get("numbers") == [1, 2] + + # Test updating existing values + state.set("numbers", [3, 4]) + assert state.get("numbers") == [1, 2, 3, 4] + + def test_state_set_with_handler(self, complex_schema): + state = State(complex_schema) + + # Test custom handler for numbers + state.set("numbers", [3, 2, 1]) + assert state.get("numbers") == [1, 2, 3] + + state.set("numbers", [6, 5, 4]) + assert state.get("numbers") == [1, 2, 3, 4, 5, 6] + + def test_state_set_with_handler_override(self, basic_schema): + state = State(basic_schema) + + # Custom handler that concatenates strings + custom_handler = lambda current, new: f"{current}-{new}" if current else new + + state.set("name", "first") + state.set("name", "second", handler_override=custom_handler) + assert state.get("name") == "first-second" + + def test_state_has(self, basic_schema): + state = State(basic_schema, {"name": "test"}) + assert state.has("name") is True + assert state.has("non_existent") is False + + def test_state_empty_schema(self): + state = State({}) + assert state.data == {} + + # Instead of comparing the entire schema directly, check structure separately + assert "messages" in state.schema + assert state.schema["messages"]["type"] == List[ChatMessage] + assert callable(state.schema["messages"]["handler"]) + + with pytest.raises(ValueError, match="Key 'any_key' not found in schema"): + state.set("any_key", "value") + + def test_state_none_values(self, basic_schema): + state = State(basic_schema) + state.set("name", None) + assert state.get("name") is None 
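+        # None is stored like any other value; a later set() simply replaces it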
+ state.set("name", "value") + assert state.get("name") == "value" + + def test_state_merge_lists(self, basic_schema): + state = State(basic_schema) + state.set("numbers", "not_a_list") + assert state.get("numbers") == ["not_a_list"] + state.set("numbers", [1, 2]) + assert state.get("numbers") == ["not_a_list", 1, 2] + + def test_state_nested_structures(self): + schema = { + "complex": { + "type": Dict[str, List[int]], + "handler": lambda current, new: { + k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys()) + } + if current + else new, + } + } + + state = State(schema) + state.set("complex", {"a": [1, 2], "b": [3, 4]}) + state.set("complex", {"b": [5, 6], "c": [7, 8]}) + + expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]} + assert state.get("complex") == expected + + def test_schema_to_dict(self, basic_schema): + expected_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}} + result = _schema_to_dict(basic_schema) + assert result == expected_dict + + def test_schema_to_dict_with_handlers(self, complex_schema): + expected_dict = { + "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"}, + "metadata": {"type": "dict"}, + "name": {"type": "str"}, + } + result = _schema_to_dict(complex_schema) + assert result == expected_dict + + def test_schema_from_dict(self, basic_schema): + schema_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}} + result = _schema_from_dict(schema_dict) + assert result == basic_schema + + def test_schema_from_dict_with_handlers(self, complex_schema): + schema_dict = { + "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"}, + "metadata": {"type": "dict"}, + "name": {"type": "str"}, + } + result = _schema_from_dict(schema_dict) + assert result == complex_schema + + def test_state_mutability(self): + state = State({"my_list": {"type": list}}, {"my_list": [1, 2]}) + + my_list = state.get("my_list") + my_list.append(3) + + assert state.get("my_list") == [1, 2] + + def test_state_to_dict(self): + # we test dict, a python type and a haystack dataclass + state_schema = { + "numbers": {"type": int}, + "messages": {"type": List[ChatMessage]}, + "dict_of_lists": {"type": dict}, + } + + data = { + "numbers": 1, + "messages": [ChatMessage.from_user(text="Hello, world!")], + "dict_of_lists": {"numbers": [1, 2, 3]}, + } + state = State(state_schema, data) + state_dict = state.to_dict() + assert state_dict["schema"] == { + "numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, + "messages": { + "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", + "handler": "haystack.components.agents.state.state_utils.merge_lists", + }, + "dict_of_lists": {"type": "dict", "handler": "haystack.components.agents.state.state_utils.replace_values"}, + } + assert state_dict["data"] == { + "serialization_schema": { + "type": "object", + "properties": { + "numbers": {"type": "integer"}, + "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}}, + "dict_of_lists": { + "type": "object", + "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}}, + }, + }, + }, + "serialized_data": { + "numbers": 1, + "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}], + "dict_of_lists": {"numbers": [1, 2, 3]}, + }, + } + + def test_state_from_dict(self): + state_dict = { + "schema": { + "numbers": 
{"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, + "messages": { + "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", + "handler": "haystack.components.agents.state.state_utils.merge_lists", + }, + "dict_of_lists": { + "type": "dict", + "handler": "haystack.components.agents.state.state_utils.replace_values", + }, + }, + "data": { + "serialization_schema": { + "type": "object", + "properties": { + "numbers": {"type": "integer"}, + "messages": { + "type": "array", + "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}, + }, + "dict_of_lists": { + "type": "object", + "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}}, + }, + }, + }, + "serialized_data": { + "numbers": 1, + "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}], + "dict_of_lists": {"numbers": [1, 2, 3]}, + }, + }, + } + state = State.from_dict(state_dict) + # Check types are correctly converted + assert state.schema["numbers"]["type"] == int + assert state.schema["dict_of_lists"]["type"] == dict + # Check handlers are functions, not comparing exact functions as they might be different references + assert callable(state.schema["numbers"]["handler"]) + assert callable(state.schema["messages"]["handler"]) + assert callable(state.schema["dict_of_lists"]["handler"]) + # Check data is correct + assert state.data["numbers"] == 1 + assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")] + assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]} + + def test_state_from_dict_legacy(self): + # this is the old format of the state dictionary + # it is kept for backward compatibility + # it will be removed in Haystack 2.16.0 + state_dict = { + "schema": { + "numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, + "messages": { + "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", + "handler": "haystack.components.agents.state.state_utils.merge_lists", + }, + "dict_of_lists": { + "type": "dict", + "handler": "haystack.components.agents.state.state_utils.replace_values", + }, + }, + "data": { + "serialization_schema": { + "numbers": {"type": "integer"}, + "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}}, + "dict_of_lists": {"type": "object"}, + }, + "serialized_data": { + "numbers": 1, + "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}], + "dict_of_lists": {"numbers": [1, 2, 3]}, + }, + }, + } + state = State.from_dict(state_dict) + # Check types are correctly converted + assert state.schema["numbers"]["type"] == int + assert state.schema["dict_of_lists"]["type"] == dict + # Check handlers are functions, not comparing exact functions as they might be different references + assert callable(state.schema["numbers"]["handler"]) + assert callable(state.schema["messages"]["handler"]) + assert callable(state.schema["dict_of_lists"]["handler"]) + # Check data is correct + assert state.data["numbers"] == 1 + assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")] + assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]} diff --git a/test/conftest.py b/test/conftest.py index 8dde8737a3..98477b957e 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,7 @@ import pytest from haystack import component, tracing +from haystack.core.pipeline.breakpoint import load_state from 
haystack.testing.test_utils import set_all_seeds
 
 from test.tracing.utils import SpyingTracer
@@ -80,3 +81,27 @@ def spying_tracer() -> Generator[SpyingTracer, None, None]:
     # Make sure to disable tracing after the test to avoid affecting other tests
     tracing.disable_tracing()
+
+
+def load_and_resume_pipeline_state(pipeline, output_directory: Path, component: str, data: Dict = None) -> Dict:
+    """
+    Utility function to load and resume pipeline state from a breakpoint file.
+
+    :param pipeline: The pipeline instance to resume.
+    :param output_directory: Directory containing the breakpoint files.
+    :param component: Component name whose breakpoint file should be loaded.
+    :param data: Data to pass to the pipeline run (defaults to an empty dict).
+    :returns: Dict containing the pipeline run results.
+    :raises ValueError: If no breakpoint file is found for the given component.
+    """
+    data = data or {}
+    all_files = list(output_directory.glob("*"))
+
+    for full_path in all_files:
+        f_name = Path(full_path).name
+        if f_name.startswith(component):
+            resume_state = load_state(full_path)
+            return pipeline.run(data=data, resume_state=resume_state)
+
+    msg = f"No files found for {component} in {output_directory}."
+    raise ValueError(msg)
diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py
new file mode 100644
index 0000000000..8f6c671171
--- /dev/null
+++ b/test/core/pipeline/test_breakpoint.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+
+import pytest
+
+from haystack.core.pipeline.breakpoint import (
+    _transform_json_structure,
+    _validate_break_point,
+    _validate_resume_state,
+    load_state,
+)
+
+
+def test_transform_json_structure_unwraps_sender_value():
+    data = {
+        "key1": [{"sender": None, "value": "some value"}],
+        "key2": [{"sender": "comp1", "value": 42}],
+        "key3": "direct value",
+    }
+
+    result = _transform_json_structure(data)
+
+    assert result == {"key1": "some value", "key2": 42, "key3": "direct value"}
+
+
+def test_transform_json_structure_handles_nested_structures():
+    data = {
+        "key1": [{"sender": None, "value": "value1"}],
+        "key2": {"nested": [{"sender": "comp1", "value": "value2"}], "direct": "value3"},
+        "key3": [[{"sender": None, "value": "value4"}], [{"sender": "comp2", "value": "value5"}]],
+    }
+
+    result = _transform_json_structure(data)
+
+    assert result == {"key1": "value1", "key2": {"nested": "value2", "direct": "value3"}, "key3": ["value4", "value5"]}
+
+
+def test_validate_resume_state_validates_required_keys():
+    state = {
+        "input_data": {},
+        "pipeline_breakpoint": {"component": "comp1", "visits": 0},
+        # Missing pipeline_state
+    }
+
+    with pytest.raises(ValueError, match="Invalid state file: missing required keys"):
+        _validate_resume_state(state)
+
+    state = {
+        "input_data": {},
+        "pipeline_breakpoint": {"component": "comp1", "visits": 0},
+        "pipeline_state": {
+            "inputs": {},
+            "component_visits": {},
+            # Missing ordered_component_names
+        },
+    }
+
+    with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"):
+        _validate_resume_state(state)
+
+
+def test_validate_resume_state_validates_component_consistency():
+    state = {
+        "input_data": {},
+        "pipeline_breakpoint": {"component": "comp1", "visits": 0},
+        "pipeline_state": {
+            "inputs": {},
+            "component_visits": {"comp1": 0, "comp2": 0},
+            "ordered_component_names": ["comp1", "comp3"],  # inconsistent with component_visits
+        },
+    }
+
+    with 
pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): + _validate_resume_state(state) + + +def test_validate_resume_state_validates_valid_state(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp2"], + }, + } + + _validate_resume_state(state) # should not raise any exception + + +def test_load_state_loads_valid_state(tmp_path): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp2"], + }, + } + state_file = tmp_path / "state.json" + with open(state_file, "w") as f: + json.dump(state, f) + + loaded_state = load_state(state_file) + assert loaded_state == state + + +def test_load_state_handles_invalid_state(tmp_path): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp3"], # inconsistent with component_visits + }, + } + + state_file = tmp_path / "invalid_state.json" + with open(state_file, "w") as f: + json.dump(state, f) + + with pytest.raises(ValueError, match="Invalid pipeline state from"): + load_state(state_file) diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py new file mode 100644 index 0000000000..bec638846c --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock, patch + +import pytest + +from haystack.components.builders.answer_builder import AnswerBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.joiners import AnswerJoiner +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_state + + +class TestPipelineBreakpoints: + @pytest.fixture + def mock_openai_chat_generator(self): + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content="Natural Language Processing (NLP) is a field of AI focused on enabling " + "computers to understand, interpret, and generate human language." 
+ ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(model_name): + generator = OpenAIChatGenerator(model=model_name, api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + if "gpt-4" in model_name: + content = ( + "Natural Language Processing (NLP) is a field of AI focused on enabling computers " + "to understand, interpret, and generate human language." + ) + else: + content = "NLP is a branch of AI that helps machines understand and process human language." + + return { + "replies": [ChatMessage.from_assistant(content)], + "meta": { + "model": model_name, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + yield create_mock_generator + + @pytest.fixture + def answer_join_pipeline(self, mock_openai_chat_generator): + """ + Creates a pipeline with mocked OpenAI components. + """ + # Create the pipeline with mocked components + pipeline = Pipeline(connection_type_validation=False) + pipeline.add_component("gpt-4o", mock_openai_chat_generator("gpt-4o")) + pipeline.add_component("gpt-3", mock_openai_chat_generator("gpt-3.5-turbo")) + pipeline.add_component("answer_builder_a", AnswerBuilder()) + pipeline.add_component("answer_builder_b", AnswerBuilder()) + pipeline.add_component("answer_joiner", AnswerJoiner()) + pipeline.connect("gpt-4o.replies", "answer_builder_a") + pipeline.connect("gpt-3.replies", "answer_builder_b") + pipeline.connect("answer_builder_a.answers", "answer_joiner") + pipeline.connect("answer_builder_b.answers", "answer_joiner") + + return pipeline + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory): + return tmp_path_factory.mktemp("output_files") + + components = [ + Breakpoint("gpt-4o", 0), + Breakpoint("gpt-3", 0), + Breakpoint("answer_builder_a", 0), + Breakpoint("answer_builder_b", 0), + Breakpoint("answer_joiner", 0), + ] + + @pytest.mark.parametrize("component", components) + @pytest.mark.integration + def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_directory, component): + """ + Test that an answer joiner pipeline can be executed with breakpoints at each component. + """ + query = "What's Natural Language Processing?" + messages = [ + ChatMessage.from_system("You are a helpful, respectful and honest assistant. 
Be super concise."), + ChatMessage.from_user(query), + ] + data = { + "gpt-4o": {"messages": messages}, + "gpt-3": {"messages": messages}, + "answer_builder_a": {"query": query}, + "answer_builder_b": {"query": query}, + } + + try: + _ = answer_join_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + except BreakpointException: + pass + + result = load_and_resume_pipeline_state(answer_join_pipeline, output_directory, component.component_name, data) + assert result["answer_joiner"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py new file mode 100644 index 0000000000..4fb41dbc87 --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from haystack.components.converters import OutputAdapter +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.joiners import BranchJoiner +from haystack.components.validators import JsonSchemaValidator +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_state + + +class TestPipelineBreakpoints: + @pytest.fixture + def mock_openai_chat_generator(self): + """ + Creates a mock for the OpenAIChatGenerator. + """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + # Create mock completion objects + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content='{"first_name": "Peter", "last_name": "Parker", "nationality": "American"}' + ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(model_name): + generator = OpenAIChatGenerator(model=model_name, api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + content = '{"first_name": "Peter", "last_name": "Parker", "nationality": "American"}' + + return { + "replies": [ChatMessage.from_assistant(content)], + "meta": { + "model": model_name, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + yield create_mock_generator + + @pytest.fixture + def branch_joiner_pipeline(self, mock_openai_chat_generator): + person_schema = { + "type": "object", + "properties": { + "first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]}, + }, + "required": ["first_name", "last_name", "nationality"], + } + + pipe = Pipeline() + pipe.add_component("joiner", BranchJoiner(List[ChatMessage])) + pipe.add_component("fc_llm", 
mock_openai_chat_generator("gpt-4o-mini")) + pipe.add_component("validator", JsonSchemaValidator(json_schema=person_schema)) + pipe.add_component("adapter", OutputAdapter("{{chat_message}}", List[ChatMessage], unsafe=True)) + + pipe.connect("adapter", "joiner") + pipe.connect("joiner", "fc_llm") + pipe.connect("fc_llm.replies", "validator.messages") + pipe.connect("validator.validation_error", "joiner") + + return pipe + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory): + return tmp_path_factory.mktemp("output_files") + + components = [ + Breakpoint("joiner", 0), + Breakpoint("fc_llm", 0), + Breakpoint("validator", 0), + Breakpoint("adapter", 0), + ] + + @pytest.mark.parametrize("component", components) + @pytest.mark.integration + def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output_directory, component): + data = { + "fc_llm": {"generation_kwargs": {"response_format": {"type": "json_object"}}}, + "adapter": {"chat_message": [ChatMessage.from_user("Create JSON from Peter Parker")]}, + } + + try: + _ = branch_joiner_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + except BreakpointException: + pass + + result = load_and_resume_pipeline_state( + branch_joiner_pipeline, output_directory, component.component_name, data + ) + assert result["validator"], "The result should be valid according to the schema." diff --git a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py new file mode 100644 index 0000000000..5c12c40d64 --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.joiners import ListJoiner +from haystack.core.errors import BreakpointException +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_state + + +class TestPipelineBreakpoints: + @pytest.fixture + def mock_openai_chat_generator(self): + """ + Creates a mock for the OpenAIChatGenerator. + """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + # Create mock completion objects + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content="Nuclear physics is the study of atomic nuclei, their constituents, " + "and their interactions." 
+                    ),
+                )
+            ]
+            mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}
+
+            mock_chat_completion_create.return_value = mock_completion
+
+            # Create a mock for the OpenAIChatGenerator
+            @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"})
+            def create_mock_generator(model_name):
+                generator = OpenAIChatGenerator(model=model_name, api_key=Secret.from_env_var("OPENAI_API_KEY"))
+
+                # Mock the run method
+                def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None):
+                    # Check if this is a feedback request or a regular query
+                    if any("feedback" in msg.text.lower() for msg in messages):
+                        content = (
+                            "Score: 8/10. The answer is concise and accurate, providing a good overview "
+                            "of nuclear physics."
+                        )
+                    else:
+                        content = (
+                            "Nuclear physics is the study of atomic nuclei, their constituents, and their interactions."
+                        )
+
+                    return {
+                        "replies": [ChatMessage.from_assistant(content)],
+                        "meta": {
+                            "model": model_name,
+                            "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+                        },
+                    }
+
+                # Replace the run method with our mock
+                generator.run = mock_run
+
+                return generator
+
+            yield create_mock_generator
+
+    @pytest.fixture
+    def list_joiner_pipeline(self, mock_openai_chat_generator):
+        user_message = [ChatMessage.from_user("Give a brief answer to the following question: {{query}}")]
+
+        feedback_prompt = """
+        You are given a question and an answer.
+        Your task is to provide a score and brief feedback on the answer.
+        Question: {{query}}
+        Answer: {{response}}
+        """
+
+        feedback_message = [ChatMessage.from_system(feedback_prompt)]
+
+        prompt_builder = ChatPromptBuilder(template=user_message)
+        feedback_prompt_builder = ChatPromptBuilder(template=feedback_message)
+        llm = mock_openai_chat_generator("gpt-4o-mini")
+        feedback_llm = mock_openai_chat_generator("gpt-4o-mini")
+
+        pipe = Pipeline()
+        pipe.add_component("prompt_builder", prompt_builder)
+        pipe.add_component("llm", llm)
+        pipe.add_component("feedback_prompt_builder", feedback_prompt_builder)
+        pipe.add_component("feedback_llm", feedback_llm)
+        pipe.add_component("list_joiner", ListJoiner(List[ChatMessage]))
+
+        pipe.connect("prompt_builder.prompt", "llm.messages")
+        pipe.connect("prompt_builder.prompt", "list_joiner")
+        pipe.connect("llm.replies", "list_joiner")
+        pipe.connect("llm.replies", "feedback_prompt_builder.response")
+        pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages")
+        pipe.connect("feedback_llm.replies", "list_joiner")
+
+        return pipe
+
+    @pytest.fixture(scope="session")
+    def output_directory(self, tmp_path_factory):
+        return tmp_path_factory.mktemp("output_files")
+
+    components = [
+        Breakpoint("prompt_builder", 0),
+        Breakpoint("llm", 0),
+        Breakpoint("feedback_prompt_builder", 0),
+        Breakpoint("feedback_llm", 0),
+        Breakpoint("list_joiner", 0),
+    ]
+
+    @pytest.mark.parametrize("component", components)
+    @pytest.mark.integration
+    def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, component):
+        query = "What is nuclear physics?"
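+        # A sketch of the break/resume flow exercised here: run() with a break_point
+        # saves the pipeline state to debug_path and raises BreakpointException once
+        # the target component is reached; load_and_resume_pipeline_state (from
+        # test.conftest) then reloads that state file with load_state() and finishes
+        # the run via pipeline.run(data=..., resume_state=...).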
+ data = { + "prompt_builder": {"template_variables": {"query": query}}, + "feedback_prompt_builder": {"template_variables": {"query": query}}, + } + + try: + _ = list_joiner_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + except BreakpointException: + pass + + result = load_and_resume_pipeline_state(list_joiner_pipeline, output_directory, component.component_name, data) + assert result["list_joiner"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py new file mode 100644 index 0000000000..2b63f4da9b --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from pathlib import Path +from typing import List, Optional +from unittest.mock import MagicMock, patch + +import pydantic +import pytest +from colorama import Fore +from pydantic import BaseModel, ValidationError + +from haystack import component +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_state +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret + + +# Define the component input parameters +@component +class OutputValidator: + def __init__(self, pydantic_model: pydantic.BaseModel): + self.pydantic_model = pydantic_model + self.iteration_counter = 0 + + # Define the component output + @component.output_types(valid_replies=List[str], invalid_replies=Optional[List[str]], error_message=Optional[str]) + def run(self, replies: List[ChatMessage]): + self.iteration_counter += 1 + + ## Try to parse the LLM's reply ## + # If the LLM's reply is a valid object, return `"valid_replies"` + try: + output_dict = json.loads(replies[0].text) + self.pydantic_model.model_validate(output_dict) + print( + Fore.GREEN + f"OutputValidator at Iteration {self.iteration_counter}: " + f"Valid JSON from LLM - No need for looping: {replies[0]}" + ) + return {"valid_replies": replies} + + # If the LLM's reply is corrupted or not valid, return "invalid_replies" and the "error_message" for the LLM + # to try again + except (ValueError, ValidationError) as e: + print( + Fore.RED + + f"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n" + f"Output from LLM:\n {replies[0]} \n" + f"Error from OutputValidator: {e}" + ) + return {"invalid_replies": replies, "error_message": str(e)} + + +class City(BaseModel): + name: str + country: str + population: int + + +class CitiesData(BaseModel): + cities: List[City] + + +class TestPipelineBreakpointsLoops: + """ + This class contains tests for pipelines with validation loops and breakpoints. + """ + + @pytest.fixture + def mock_openai_chat_generator(self): + """ + Creates a mock for the OpenAIChatGenerator that returns valid JSON responses. 
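+        Both the initial attempt and the retry branch return the same canned
+        CitiesData payload, so the validation loop is expected to exit on the
+        first pass.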
+ """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + # Create mock completion objects + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content='{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}, ' + '{"name": "Paris", "country": "France", "population": 2161000}, ' + '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}' + ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(): + generator = OpenAIChatGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + # Check if this is a retry attempt + if any("You already created the following output" in msg.text for msg in messages): + # Return a valid JSON response for retry attempts + return { + "replies": [ + ChatMessage.from_assistant( + '{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}, ' + '{"name": "Paris", "country": "France", "population": 2161000}, ' + '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}' + ) + ], + "meta": { + "model": "gpt-4", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + else: + # Return a valid JSON response for initial attempts + return { + "replies": [ + ChatMessage.from_assistant( + '{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}, ' + '{"name": "Paris", "country": "France", "population": 2161000}, ' + '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}' + ) + ], + "meta": { + "model": "gpt-4", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + yield create_mock_generator + + @pytest.fixture + def validation_loop_pipeline(self, mock_openai_chat_generator): + """Create a pipeline with validation loops for testing.""" + prompt_template = [ + ChatMessage.from_user( + """ + Create a JSON object from the information present in this passage: {{passage}}. + Only use information that is present in the passage. Follow this JSON schema, but only return the + actual instances without any additional schema definition: + {{schema}} + Make sure your response is a dict and not a list. + {% if invalid_replies and error_message %} + You already created the following output in a previous attempt: {{invalid_replies}} + However, this doesn't comply with the format requirements from above and triggered this + Python exception: {{error_message}} + Correct the output and try again. Just return the corrected output without any extra explanations. 
+                {% endif %}
+                """
+            )
+        ]
+
+        pipeline = Pipeline(max_runs_per_component=5)
+        pipeline.add_component(instance=ChatPromptBuilder(template=prompt_template), name="prompt_builder")
+        pipeline.add_component(instance=mock_openai_chat_generator(), name="llm")
+        pipeline.add_component(instance=OutputValidator(pydantic_model=CitiesData), name="output_validator")
+
+        # Connect components
+        pipeline.connect("prompt_builder.prompt", "llm.messages")
+        pipeline.connect("llm.replies", "output_validator")
+        pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
+        pipeline.connect("output_validator.error_message", "prompt_builder.error_message")
+
+        return pipeline
+
+    @pytest.fixture(scope="session")
+    def output_directory(self, tmp_path_factory):
+        return tmp_path_factory.mktemp("output_files")
+
+    @pytest.fixture
+    def test_data(self):
+        json_schema = {
+            "cities": [
+                {"name": "Berlin", "country": "Germany", "population": 3850809},
+                {"name": "Paris", "country": "France", "population": 2161000},
+                {"name": "Lisbon", "country": "Portugal", "population": 504718},
+            ]
+        }
+
+        passage = (
+            "Berlin is the capital of Germany. It has a population of 3,850,809. Paris, France's capital, has "
+            "2.161 million residents. Lisbon is the capital and the largest city of Portugal with a "
+            "population of 504,718."
+        )
+
+        return {"schema": json_schema, "passage": passage}
+
+    components = [Breakpoint("prompt_builder", 0), Breakpoint("llm", 0), Breakpoint("output_validator", 0)]
+
+    @pytest.mark.parametrize("component", components)
+    @pytest.mark.integration
+    def test_pipeline_breakpoints_validation_loop(
+        self, validation_loop_pipeline, output_directory, test_data, component
+    ):
+        """
+        Test that a pipeline with validation loops can be executed with breakpoints at each component.
+        """
+        data = {"prompt_builder": {"passage": test_data["passage"], "schema": test_data["schema"]}}
+
+        try:
+            _ = validation_loop_pipeline.run(data, break_point=component, debug_path=str(output_directory))
+        except BreakpointException:
+            pass
+
+        all_files = list(output_directory.glob("*"))
+        file_found = False
+        for full_path in all_files:
+            f_name = Path(full_path).name
+            if f_name.startswith(component.component_name):
+                file_found = True
+                resume_state = load_state(full_path)
+                result = validation_loop_pipeline.run(data={}, resume_state=resume_state)
+                # Verify the result contains valid output
+                if "output_validator" in result and "valid_replies" in result["output_validator"]:
+                    valid_reply = result["output_validator"]["valid_replies"][0].text
+                    valid_json = json.loads(valid_reply)
+                    assert isinstance(valid_json, dict)
+                    assert "cities" in valid_json
+                    cities_data = CitiesData.model_validate(valid_json)
+                    assert len(cities_data.cities) == 3
+        if not file_found:
+            msg = f"No files found for {component} in {output_directory}."
+            raise ValueError(msg)
diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py
new file mode 100644
index 0000000000..affb836dcc
--- /dev/null
+++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py
@@ -0,0 +1,292 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from haystack import Document
+from haystack.components.builders.answer_builder import AnswerBuilder
+from haystack.components.builders.prompt_builder import PromptBuilder
+from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
+from haystack.components.generators import OpenAIGenerator
+from haystack.components.joiners import DocumentJoiner
+from haystack.components.rankers import TransformersSimilarityRanker
+from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
+from haystack.components.writers import DocumentWriter
+from haystack.core.errors import BreakpointException
+from haystack.core.pipeline.pipeline import Pipeline
+from haystack.dataclasses.breakpoints import Breakpoint
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.document_stores.types import DuplicatePolicy
+from haystack.utils.auth import Secret
+from test.conftest import load_and_resume_pipeline_state
+
+
+class TestPipelineBreakpoints:
+    """
+    This class contains tests for pipelines with breakpoints.
+    """
+
+    @pytest.fixture
+    def mock_sentence_transformers_doc_embedder(self):
+        with patch(
+            "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"  # noqa: E501
+        ) as mock_doc_embedder:
+            mock_model = MagicMock()
+            mock_doc_embedder.return_value = mock_model
+
+            # the mock returns a fixed embedding
+            def mock_encode(
+                documents, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs
+            ):
+                import numpy as np
+
+                return [np.ones(384).tolist() for _ in documents]
+
+            mock_model.encode = mock_encode
+            embedder = SentenceTransformersDocumentEmbedder(model="mock-model", progress_bar=False)
+
+            # mocked run method to return a fixed embedding
+            def mock_run(documents: list[Document]):
+                if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+                    raise TypeError(
+                        "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
+                        "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
+ ) + + import numpy as np + + embedding = np.ones(384).tolist() + + # Add the embedding to each document + for doc in documents: + doc.embedding = embedding + + # Return the documents with embeddings, matching the actual implementation + return {"documents": documents} + + # mocked run + embedder.run = mock_run + + # initialize the component + embedder.warm_up() + + return embedder + + @pytest.fixture + def document_store(self, mock_sentence_transformers_doc_embedder): + """Create and populate a document store for testing.""" + documents = [ + Document(content="My name is Jean and I live in Paris."), + Document(content="My name is Mark and I live in Berlin."), + Document(content="My name is Giorgio and I live in Rome."), + ] + + document_store = InMemoryDocumentStore() + doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP) + ingestion_pipe = Pipeline() + ingestion_pipe.add_component(instance=mock_sentence_transformers_doc_embedder, name="doc_embedder") + ingestion_pipe.add_component(instance=doc_writer, name="doc_writer") + ingestion_pipe.connect("doc_embedder.documents", "doc_writer.documents") + ingestion_pipe.run({"doc_embedder": {"documents": documents}}) + + return document_store + + @pytest.fixture + def mock_openai_completion(self): + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + mock_completion = MagicMock() + mock_completion.model = "gpt-4o-mini" + mock_completion.choices = [ + MagicMock(finish_reason="stop", index=0, message=MagicMock(content="Mark lives in Berlin.")) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + yield mock_chat_completion_create + + @pytest.fixture + def mock_transformers_similarity_ranker(self): + """ + This mock simulates the behavior of the ranker without loading the actual model. 
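+        The mocked run assigns random scores, so tests should assert on result
+        structure rather than on a specific ranking order.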
+        """
+        with (
+            patch(
+                "haystack.components.rankers.transformers_similarity.AutoModelForSequenceClassification"
+            ) as mock_model_class,
+            patch("haystack.components.rankers.transformers_similarity.AutoTokenizer") as mock_tokenizer_class,
+        ):
+            mock_model = MagicMock()
+            mock_tokenizer = MagicMock()
+
+            mock_model_class.from_pretrained.return_value = mock_model
+            mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
+
+            ranker = TransformersSimilarityRanker(model="mock-model", top_k=5, scale_score=True, calibration_factor=1.0)
+
+            def mock_run(query, documents, top_k=None, scale_score=None, calibration_factor=None, score_threshold=None):
+                # assign random scores
+                import random
+
+                ranked_docs = documents.copy()
+                for doc in ranked_docs:
+                    doc.score = random.random()  # random score between 0 and 1
+
+                # sort in reverse order and select top_k if provided
+                ranked_docs.sort(key=lambda x: x.score, reverse=True)
+                if top_k is not None:
+                    ranked_docs = ranked_docs[:top_k]
+                else:
+                    ranked_docs = ranked_docs[: ranker.top_k]
+
+                # apply score threshold if provided
+                if score_threshold is not None:
+                    ranked_docs = [doc for doc in ranked_docs if doc.score >= score_threshold]
+
+                return {"documents": ranked_docs}
+
+            # replace the run method with our mock
+            ranker.run = mock_run
+
+            # warm_up to initialize the component
+            ranker.warm_up()
+
+            return ranker
+
+    @pytest.fixture
+    def mock_sentence_transformers_text_embedder(self):
+        """
+        Simulates the behavior of the embedder without loading the actual model.
+        """
+        with patch(
+            "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"  # noqa: E501
+        ) as mock_text_embedder:
+            mock_model = MagicMock()
+            mock_text_embedder.return_value = mock_model
+
+            # the mock returns a fixed embedding
+            def mock_encode(
+                texts, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs
+            ):
+                import numpy as np
+
+                return [np.ones(384).tolist() for _ in texts]
+
+            mock_model.encode = mock_encode
+            embedder = SentenceTransformersTextEmbedder(model="mock-model", progress_bar=False)
+
+            # mocked run method to return a fixed embedding
+            def mock_run(text):
+                if not isinstance(text, str):
+                    raise TypeError(
+                        "SentenceTransformersTextEmbedder expects a string as input. "
+                        "In case you want to embed a list of Documents, please use the "
+                        "SentenceTransformersDocumentEmbedder."
+ ) + + import numpy as np + + embedding = np.ones(384).tolist() + return {"embedding": embedding} + + # mocked run + embedder.run = mock_run + + # initialize the component + embedder.warm_up() + + return embedder + + @pytest.fixture + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def hybrid_rag_pipeline( + self, document_store, mock_transformers_similarity_ranker, mock_sentence_transformers_text_embedder + ): + """Create a hybrid RAG pipeline for testing.""" + + prompt_template = """ + Given these documents, answer the question based on the document content only.\nDocuments: + {% for doc in documents %} + {{ doc.content }} + {% endfor %} + + \nQuestion: {{question}} + \nAnswer: + """ + pipeline = Pipeline(connection_type_validation=False) + pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever") + + # Use the mocked embedder instead of creating a new one + pipeline.add_component(instance=mock_sentence_transformers_text_embedder, name="query_embedder") + + pipeline.add_component( + instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever" + ) + pipeline.add_component(instance=DocumentJoiner(sort_by_score=False), name="doc_joiner") + + # Use the mocked ranker instead of the real one + pipeline.add_component(instance=mock_transformers_similarity_ranker, name="ranker") + + pipeline.add_component( + instance=PromptBuilder(template=prompt_template, required_variables=["documents", "question"]), + name="prompt_builder", + ) + + # Use a mocked API key for the OpenAIGenerator + pipeline.add_component(instance=OpenAIGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY")), name="llm") + pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") + + pipeline.connect("query_embedder", "embedding_retriever.query_embedding") + pipeline.connect("embedding_retriever", "doc_joiner.documents") + pipeline.connect("bm25_retriever", "doc_joiner.documents") + pipeline.connect("doc_joiner", "ranker.documents") + pipeline.connect("ranker", "prompt_builder.documents") + pipeline.connect("prompt_builder", "llm") + pipeline.connect("llm.replies", "answer_builder.replies") + pipeline.connect("llm.meta", "answer_builder.meta") + pipeline.connect("doc_joiner", "answer_builder.documents") + + return pipeline + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory): + return tmp_path_factory.mktemp("output_files") + + components = [ + Breakpoint("bm25_retriever", 0), + Breakpoint("query_embedder", 0), + Breakpoint("embedding_retriever", 0), + Breakpoint("doc_joiner", 0), + Breakpoint("ranker", 0), + Breakpoint("prompt_builder", 0), + Breakpoint("llm", 0), + Breakpoint("answer_builder", 0), + ] + + @pytest.mark.parametrize("component", components) + @pytest.mark.integration + def test_pipeline_breakpoints_hybrid_rag( + self, hybrid_rag_pipeline, document_store, output_directory, component, mock_openai_completion + ): + """ + Test that a hybrid RAG pipeline can be executed with breakpoints at each component. + """ + # Test data + question = "Where does Mark live?" 
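+        # The same question fans out to each entry point of the hybrid graph: both
+        # retrievers, the ranker, the prompt builder, and the answer builder receive
+        # it under their own input names in the data dict below.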
+ data = { + "query_embedder": {"text": question}, + "bm25_retriever": {"query": question}, + "ranker": {"query": question, "top_k": 5}, + "prompt_builder": {"question": question}, + "answer_builder": {"query": question}, + } + + try: + _ = hybrid_rag_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + except BreakpointException: + pass + + result = load_and_resume_pipeline_state(hybrid_rag_pipeline, output_directory, component.component_name, data) + assert result["answer_builder"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py new file mode 100644 index 0000000000..8c01b5973b --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder +from haystack.components.converters import OutputAdapter +from haystack.components.joiners import StringJoiner +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from test.conftest import load_and_resume_pipeline_state + + +class TestPipelineBreakpoints: + @pytest.fixture + def string_joiner_pipeline(self): + pipeline = Pipeline() + pipeline.add_component( + "prompt_builder_1", ChatPromptBuilder(template=[ChatMessage.from_user("Builder 1: {{query}}")]) + ) + pipeline.add_component( + "prompt_builder_2", ChatPromptBuilder(template=[ChatMessage.from_user("Builder 2: {{query}}")]) + ) + pipeline.add_component("adapter_1", OutputAdapter("{{messages[0].text}}", output_type=str)) + pipeline.add_component("adapter_2", OutputAdapter("{{messages[0].text}}", output_type=str)) + pipeline.add_component("string_joiner", StringJoiner()) + + pipeline.connect("prompt_builder_1.prompt", "adapter_1.messages") + pipeline.connect("prompt_builder_2.prompt", "adapter_2.messages") + pipeline.connect("adapter_1", "string_joiner.strings") + pipeline.connect("adapter_2", "string_joiner.strings") + + return pipeline + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory): + return tmp_path_factory.mktemp("output_files") + + components = [ + Breakpoint("prompt_builder_1", 0), + Breakpoint("prompt_builder_2", 0), + Breakpoint("adapter_1", 0), + Breakpoint("adapter_2", 0), + Breakpoint("string_joiner", 0), + ] + + @pytest.mark.parametrize("component", components) + @pytest.mark.integration + def test_string_joiner_pipeline(self, string_joiner_pipeline, output_directory, component): + string_1 = "What's Natural Language Processing?" + string_2 = "What is life?" 
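+        # Both adapters feed StringJoiner's variadic "strings" input, so a run
+        # resumed from any of the five breakpoints should still produce the joined
+        # result asserted below.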
+ data = {"prompt_builder_1": {"query": string_1}, "prompt_builder_2": {"query": string_2}} + + try: + _ = string_joiner_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + except BreakpointException: + pass + + result = load_and_resume_pipeline_state( + string_joiner_pipeline, output_directory, component.component_name, data + ) + assert result["string_joiner"] diff --git a/test/core/pipeline/test_pipeline_experimental.py b/test/core/pipeline/test_pipeline_experimental.py new file mode 100644 index 0000000000..c772e4f022 --- /dev/null +++ b/test/core/pipeline/test_pipeline_experimental.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from haystack.components.joiners import BranchJoiner +from haystack.core.component import component +from haystack.core.errors import PipelineRuntimeError +from haystack.core.pipeline.pipeline import Pipeline + + +class TestPipeline: + """ + This class contains only unit tests for the Pipeline class. + It doesn't test Pipeline.run(), that is done separately in a different way. + """ + + def test_pipeline_thread_safety(self, waiting_component, spying_tracer): + # Initialize pipeline with synchronous components + pp = Pipeline() + pp.add_component("wait", waiting_component()) + + run_data = [{"wait_for": 0.001}, {"wait_for": 0.002}] + + # Use ThreadPoolExecutor to run pipeline calls in parallel + with ThreadPoolExecutor(max_workers=len(run_data)) as executor: + # Submit pipeline runs to the executor + futures = [executor.submit(pp.run, data) for data in run_data] + + # Wait for all futures to complete + for future in futures: + future.result() + + # Verify component visits using tracer + component_spans = [sp for sp in spying_tracer.spans if sp.operation_name == "haystack.component.run"] + + for span in component_spans: + assert span.tags["haystack.component.visits"] == 1 + + def test_prepare_component_inputs(self): + joiner_1 = BranchJoiner(type_=str) + joiner_2 = BranchJoiner(type_=str) + pp = Pipeline() + component_name = "joiner_1" + pp.add_component(component_name, joiner_1) + pp.add_component("joiner_2", joiner_2) + pp.connect(component_name, "joiner_2") + inputs = {"joiner_1": {"value": [{"sender": None, "value": "test_value"}]}} + comp_dict = pp._get_component_with_graph_metadata_and_visits(component_name, 0) + + _ = pp._consume_component_inputs(component_name=component_name, component=comp_dict, inputs=inputs) + # We remove input in greedy variadic sockets, even if they are from the user + assert inputs == {"joiner_1": {}} + + def test__run_component_success(self): + """Test successful component execution""" + joiner_1 = BranchJoiner(type_=str) + joiner_2 = BranchJoiner(type_=str) + pp = Pipeline() + component_name = "joiner_1" + pp.add_component(component_name, joiner_1) + pp.add_component("joiner_2", joiner_2) + pp.connect(component_name, "joiner_2") + inputs = {"value": ["test_value"]} + + outputs = pp._run_component( + component_name=component_name, + component=pp._get_component_with_graph_metadata_and_visits(component_name, 0), + inputs=inputs, + component_visits={component_name: 0, "joiner_2": 0}, + ) + + assert outputs == {"value": "test_value"} + + def test__run_component_fail(self): + """Test error when component doesn't return a dictionary""" + + @component + class WrongOutput: + @component.output_types(output=str) + def run(self, value: str): + return "not_a_dict" + + wrong = 
WrongOutput() + pp = Pipeline() + pp.add_component("wrong", wrong) + inputs = {"value": "test_value"} + + with pytest.raises(PipelineRuntimeError) as exc_info: + pp._run_component( + component_name="wrong", + component=pp._get_component_with_graph_metadata_and_visits("wrong", 0), + inputs=inputs, + component_visits={"wrong": 0}, + ) + + assert "Expected a dict" in str(exc_info.value) + + def test_run(self): + joiner_1 = BranchJoiner(type_=str) + joiner_2 = BranchJoiner(type_=str) + pp = Pipeline() + pp.add_component("joiner_1", joiner_1) + pp.add_component("joiner_2", joiner_2) + pp.connect("joiner_1", "joiner_2") + + _ = pp.run({"value": "test_value"}) From 28c84bf3f5771cd15d62606c32ea6b4c8b19fb01 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Mon, 14 Jul 2025 17:53:05 +0200 Subject: [PATCH 02/21] wip: fixing tests --- .../core/pipeline/delete_base_experimental.py | 60 --- .../pipeline/delete_pipeline_experimental.py | 393 ------------------ test/core/pipeline/test_pipeline.py | 10 + .../pipeline/test_pipeline_experimental.py | 110 ----- 4 files changed, 10 insertions(+), 563 deletions(-) delete mode 100644 haystack/core/pipeline/delete_base_experimental.py delete mode 100644 haystack/core/pipeline/delete_pipeline_experimental.py delete mode 100644 test/core/pipeline/test_pipeline_experimental.py diff --git a/haystack/core/pipeline/delete_base_experimental.py b/haystack/core/pipeline/delete_base_experimental.py deleted file mode 100644 index 94ea2f749e..0000000000 --- a/haystack/core/pipeline/delete_base_experimental.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Dict - -from haystack.core.pipeline.base import PipelineBase as HaystackPipelineBase -from haystack.core.pipeline.component_checks import _NO_OUTPUT_PRODUCED, is_socket_lazy_variadic - - -class PipelineBase(HaystackPipelineBase): - @staticmethod - def _consume_component_inputs( - component_name: str, component: Dict, inputs: Dict, is_resume: bool = False - ) -> Dict[str, Any]: - """ - Extracts the inputs needed to run for the component and removes them from the global inputs state. - - :param component_name: The name of a component. - :param component: Component with component metadata. - :param inputs: Global inputs state. - :returns: The inputs for the component. - """ - component_inputs = inputs.get(component_name, {}) - consumed_inputs = {} - greedy_inputs_to_remove = set() - for socket_name, socket in component["input_sockets"].items(): - socket_inputs = component_inputs.get(socket_name, []) - socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED] - - # if we are resuming a component, the inputs are already consumed, so we just return the first input - if is_resume: - consumed_inputs[socket_name] = socket_inputs[0] - continue - if socket_inputs: - if not socket.is_variadic: - # We only care about the first input provided to the socket. - consumed_inputs[socket_name] = socket_inputs[0] - elif socket.is_greedy: - # We need to keep track of greedy inputs because we always remove them, even if they come from - # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run - # indefinitely. - greedy_inputs_to_remove.add(socket_name) - consumed_inputs[socket_name] = [socket_inputs[0]] - elif is_socket_lazy_variadic(socket): - # We use all inputs provided to the socket on a lazy variadic socket. 
- consumed_inputs[socket_name] = socket_inputs - - # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs). - pruned_inputs = { - socket_name: [ - sock for sock in socket if sock["sender"] is None and not socket_name in greedy_inputs_to_remove - ] - for socket_name, socket in component_inputs.items() - } - pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0} - - inputs[component_name] = pruned_inputs - - return consumed_inputs diff --git a/haystack/core/pipeline/delete_pipeline_experimental.py b/haystack/core/pipeline/delete_pipeline_experimental.py deleted file mode 100644 index 78d693b48d..0000000000 --- a/haystack/core/pipeline/delete_pipeline_experimental.py +++ /dev/null @@ -1,393 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -# pylint: disable=too-many-return-statements, too-many-positional-arguments - - -from copy import deepcopy -from pathlib import Path -from typing import Any, Dict, Optional, Set, Union - -from haystack_experimental.core.errors import PipelineInvalidResumeStateError -from haystack_experimental.core.pipeline.base import PipelineBase - -from haystack import logging, tracing -from haystack.core.pipeline.base import ComponentPriority -from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline -from haystack.telemetry import pipeline_running -from haystack.utils import _deserialize_value_with_schema - -from ...components.agents import Agent -from ...dataclasses.breakpoints import AgentBreakpoint, Breakpoint -from .breakpoint import ( - _validate_break_point, - _validate_components_against_pipeline, - check_regular_break_point, - handle_agent_break_point, - trigger_break_point, -) - -logger = logging.getLogger(__name__) - - -# We inherit from both HaystackPipeline and PipelineBase to ensure that we have the -# necessary methods and properties from both classes. -class Pipeline(HaystackPipeline, PipelineBase): - """ - Synchronous version of the orchestration engine. - - Orchestrates component execution according to the execution graph, one after the other. - """ - - def _handle_resume_state(self, resume_state: Dict[str, Any]) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: - """ - Handle resume state initialization. - - :param resume_state: The resume state to handle - :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) - """ - if resume_state.get("agent_name"): - return self._handle_agent_resume_state(resume_state) - else: - return self._handle_regular_resume_state(resume_state) - - def _handle_agent_resume_state( - self, resume_state: Dict[str, Any] - ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: - """ - Handle agent-specific resume state. 
- - :param resume_state: The resume state to handle - :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) - """ - agent_name = resume_state["agent_name"] - for name, component in self.graph.nodes.items(): - if component["instance"].__class__.__name__ == "Agent" and name == agent_name: - main_pipeline_state = resume_state.get("main_pipeline_state", {}) - component_visits = main_pipeline_state.get("component_visits", {}) - ordered_component_names = main_pipeline_state.get("ordered_component_names", []) - data = _deserialize_value_with_schema(main_pipeline_state.get("inputs", {})) - return component_visits, data, True, ordered_component_names - - # Fallback to regular resume if agent not found - return self._handle_regular_resume_state(resume_state) - - def _handle_regular_resume_state( - self, resume_state: Dict[str, Any] - ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: - """ - Handle regular component resume state. - - :param resume_state: The resume state to handle - :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) - """ - component_visits, data, resume_state, ordered_component_names = self.inject_resume_state_into_graph( - resume_state=resume_state - ) - data = _deserialize_value_with_schema(resume_state["pipeline_state"]["inputs"]) - return component_visits, data, False, ordered_component_names - - def run( # noqa: PLR0915, PLR0912 - self, - data: Dict[str, Any], - include_outputs_from: Optional[Set[str]] = None, - break_point: Optional[Union[Breakpoint, AgentBreakpoint]] = None, - resume_state: Optional[Dict[str, Any]] = None, - debug_path: Optional[Union[str, Path]] = None, - ) -> Dict[str, Any]: - """ - Runs the Pipeline with given input data. - - Usage: - ```python - from haystack import Pipeline, Document - from haystack.utils import Secret - from haystack.document_stores.in_memory import InMemoryDocumentStore - from haystack.components.retrievers.in_memory import InMemoryBM25Retriever - from haystack.components.generators import OpenAIGenerator - from haystack.components.builders.answer_builder import AnswerBuilder - from haystack.components.builders.prompt_builder import PromptBuilder - - # Write documents to InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents([ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome.") - ]) - - prompt_template = \"\"\" - Given these documents, answer the question. - Documents: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} - Question: {{question}} - Answer: - \"\"\" - - retriever = InMemoryBM25Retriever(document_store=document_store) - prompt_builder = PromptBuilder(template=prompt_template) - llm = OpenAIGenerator(api_key=Secret.from_token(api_key)) - - rag_pipeline = Pipeline() - rag_pipeline.add_component("retriever", retriever) - rag_pipeline.add_component("prompt_builder", prompt_builder) - rag_pipeline.add_component("llm", llm) - rag_pipeline.connect("retriever", "prompt_builder.documents") - rag_pipeline.connect("prompt_builder", "llm") - - # Ask a question - question = "Who lives in Paris?" - results = rag_pipeline.run( - { - "retriever": {"query": question}, - "prompt_builder": {"question": question}, - } - ) - - print(results["llm"]["replies"]) - # Jean lives in Paris - ``` - - :param data: - A dictionary of inputs for the pipeline's components. 
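It helps to see the shape of the state these handlers consume. As implied by the key accesses in this file, a saved resume state looks roughly like the sketch below (placeholder values; the authoritative schema lives in the breakpoint module added by this PR):

```python
resume_state = {
    "pipeline_breakpoint": {"component": "retriever", "visits": 0},
    "pipeline_state": {
        "inputs": {},                # serialized inputs, restored with _deserialize_value_with_schema
        "component_visits": {"retriever": 0, "llm": 0},
        "ordered_component_names": ["llm", "retriever"],
    },
    "agent_name": None,              # set when the breakpoint fired inside an Agent component
    "main_pipeline_state": {},       # outer-pipeline snapshot used when resuming an Agent
}
```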
Each key is a component name - and its value is a dictionary of that component's input parameters: - ``` - data = { - "comp1": {"input1": 1, "input2": 2}, - } - ``` - For convenience, this format is also supported when input names are unique: - ``` - data = { - "input1": 1, "input2": 2, - } - ``` - :param include_outputs_from: - Set of component names whose individual outputs are to be - included in the pipeline's output. For components that are - invoked multiple times (in a loop), only the last-produced - output is included. - - :param break_point: - A set of breakpoints that can be used to debug the pipeline execution. - - :param resume_state: - A dictionary containing the state of a previously saved pipeline execution. - - :param debug_path: - Path to the directory where the pipeline state should be saved. - - :returns: - A dictionary where each entry corresponds to a component name - and its output. If `include_outputs_from` is `None`, this dictionary - will only contain the outputs of leaf components, i.e., components - without outgoing connections. - - :raises ValueError: - If invalid inputs are provided to the pipeline. - :raises PipelineRuntimeError: - If the Pipeline contains cycles with unsupported connections that would cause - it to get stuck and fail running. - Or if a Component fails or returns output in an unsupported type. - :raises PipelineMaxComponentRuns: - If a Component reaches the maximum number of times it can be run in this Pipeline. - :raises PipelineBreakpointException: - When a pipeline_breakpoint is triggered. Contains the component name, state, and partial results. - """ - pipeline_running(self) - - if break_point and resume_state: - msg = ( - "pipeline_breakpoint and resume_state cannot be provided at the same time. " - "The pipeline run will be aborted." - ) - raise PipelineInvalidResumeStateError(message=msg) - - # make sure all breakpoints are valid, i.e. reference components in the pipeline - if break_point: - _validate_break_point(break_point, self.graph) - - # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not - # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() - self.warm_up() - - if include_outputs_from is None: - include_outputs_from = set() - - if not resume_state: - # normalize `data` - data = self._prepare_component_input_data(data) - - # Raise ValueError if input is malformed in some way - self.validate_input(data) - - # We create a list of components in the pipeline sorted by name, so that the algorithm runs - # deterministically and independent of insertion order into the pipeline. - ordered_component_names = sorted(self.graph.nodes.keys()) - - # We track component visits to decide if a component can run. - component_visits = dict.fromkeys(ordered_component_names, 0) - resume_agent_in_pipeline = False - - else: - # Handle resume state - component_visits, data, resume_agent_in_pipeline, ordered_component_names = self._handle_resume_state( - resume_state - ) - - cached_topological_sort = None - # We need to access a component's receivers multiple times during a pipeline run. - # We store them here for easy access. 
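Taken together, `break_point`, `resume_state`, and `debug_path` give a break-and-resume workflow. A minimal sketch with two joiners (the `Breakpoint` constructor arguments and the state-loading step are assumptions; the exact dataclass fields live in `haystack.dataclasses.breakpoints`):

```python
from haystack import Pipeline
from haystack.components.joiners import BranchJoiner
from haystack.core.errors import BreakpointException
from haystack.dataclasses.breakpoints import Breakpoint

pp = Pipeline()
pp.add_component("joiner_1", BranchJoiner(type_=str))
pp.add_component("joiner_2", BranchJoiner(type_=str))
pp.connect("joiner_1", "joiner_2")

# Assumed fields: the component to stop at and the visit count at which to stop.
bp = Breakpoint("joiner_2", 0)

try:
    pp.run({"value": "test_value"}, break_point=bp, debug_path="debug_states/")
except BreakpointException:
    pass  # execution stopped at "joiner_2"; its state was written under debug_path

# Resume later by loading the saved state file (loading helper not shown here):
# saved_state = ...
# pp.run({}, resume_state=saved_state)
```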
- cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} - - pipeline_outputs: Dict[str, Any] = {} - with tracing.tracer.trace( - "haystack.pipeline.run", - tags={ - "haystack.pipeline.input_data": data, - "haystack.pipeline.output_data": pipeline_outputs, - "haystack.pipeline.metadata": self.metadata, - "haystack.pipeline.max_runs_per_component": self._max_runs_per_component, - }, - ) as span: - inputs = self._convert_to_internal_format(pipeline_inputs=data) - priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) - - # check if pipeline is blocked before execution - self.validate_pipeline(priority_queue) - - while True: - candidate = self._get_next_runnable_component(priority_queue, component_visits) - if candidate is None: - break - - priority, component_name, component = candidate - - if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: - component_name, topological_sort = self._tiebreak_waiting_components( - component_name=component_name, - priority=priority, - priority_queue=priority_queue, - topological_sort=cached_topological_sort, - ) - - cached_topological_sort = topological_sort - component = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) - - is_resume = bool(resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name) - component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs, is_resume=is_resume - ) - - # We need to add missing defaults using default values from input sockets because the run signature - # might not provide these defaults for components with inputs defined dynamically upon component - # initialization - component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) - - # Scenario 1: Resume state is provided to resume the pipeline at a specific component - # Deserialize the component_inputs if they are passed in resume state - # this check will prevent other component_inputs generated at runtime from being deserialized - if resume_state and component_name in resume_state["pipeline_state"]["inputs"].keys(): - for key, value in component_inputs.items(): - component_inputs[key] = _deserialize_value_with_schema(value) - - # Scenario 2: a breakpoint is provided to stop the pipeline at a specific component and visit count - breakpoint_triggered = False - if break_point is not None: - agent_breakpoint = False - - if isinstance(break_point, AgentBreakpoint): - component_instance = component["instance"] - if isinstance(component_instance, Agent): - component_inputs = handle_agent_break_point( - break_point, - component_name, - component_inputs, - inputs, - component_visits, - ordered_component_names, - data, - debug_path, - ) - agent_breakpoint = True - - if not agent_breakpoint and isinstance(break_point, Breakpoint): - breakpoint_triggered = check_regular_break_point(break_point, component_name, component_visits) - - if breakpoint_triggered: - trigger_break_point( - component_name, - component_inputs, - inputs, - component_visits, - debug_path, - data, - ordered_component_names, - pipeline_outputs, - ) - - if resume_agent_in_pipeline: - # inject the resume_state into the component (the Agent) inputs - component_inputs["resume_state"] = resume_state - component_inputs["break_point"] = None - - component_outputs = self._run_component( - component_name=component_name, - 
component=component, - inputs=component_inputs, # the inputs to the current component - component_visits=component_visits, - parent_span=span, - ) - - # Updates global input state with component outputs and returns outputs that should go to - # pipeline outputs. - component_pipeline_outputs = self._write_component_outputs( - component_name=component_name, - component_outputs=component_outputs, - inputs=inputs, - receivers=cached_receivers[component_name], - include_outputs_from=include_outputs_from, - ) - - if component_pipeline_outputs: - pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) - if self._is_queue_stale(priority_queue): - priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) - - if break_point and not agent_breakpoint: - logger.warning( - "The given breakpoint {break_point} was never triggered. This is because:\n" - "1. The provided component is not a part of the pipeline execution path.\n" - "2. The component did not reach the visit count specified in the pipeline_breakpoint", - pipeline_breakpoint=break_point, - ) - - return pipeline_outputs - - def inject_resume_state_into_graph(self, resume_state): - """ - Loads the resume state from a file and injects it into the pipeline graph. - - """ - # We previously check if the resume_state is None but this is needed to prevent a typing error - if not resume_state: - raise PipelineInvalidResumeStateError("Cannot inject resume state: resume_state is None") - - # check if the resume_state is valid for the current pipeline - _validate_components_against_pipeline(resume_state, self.graph) - - data = self._prepare_component_input_data(resume_state["pipeline_state"]["inputs"]) - component_visits = resume_state["pipeline_state"]["component_visits"] - ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"] - logger.info( - "Resuming pipeline from {component} with visit count {visits}", - component=resume_state["pipeline_breakpoint"]["component"], - visits=resume_state["pipeline_breakpoint"]["visits"], - ) - - return component_visits, data, resume_state, ordered_component_names diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 20514b7325..7f5962d42d 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -123,3 +123,13 @@ def run(self): component_visits={"erroring_component": 0}, ) assert "Component name: 'erroring_component'" in str(exc_info.value) + + def test_run(self): + joiner_1 = BranchJoiner(type_=str) + joiner_2 = BranchJoiner(type_=str) + pp = Pipeline() + pp.add_component("joiner_1", joiner_1) + pp.add_component("joiner_2", joiner_2) + pp.connect("joiner_1", "joiner_2") + + _ = pp.run({"value": "test_value"}) diff --git a/test/core/pipeline/test_pipeline_experimental.py b/test/core/pipeline/test_pipeline_experimental.py deleted file mode 100644 index c772e4f022..0000000000 --- a/test/core/pipeline/test_pipeline_experimental.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -from concurrent.futures import ThreadPoolExecutor - -import pytest - -from haystack.components.joiners import BranchJoiner -from haystack.core.component import component -from haystack.core.errors import PipelineRuntimeError -from haystack.core.pipeline.pipeline import Pipeline - - -class TestPipeline: - """ - This class contains only unit tests for the Pipeline class. 
- It doesn't test Pipeline.run(), that is done separately in a different way. - """ - - def test_pipeline_thread_safety(self, waiting_component, spying_tracer): - # Initialize pipeline with synchronous components - pp = Pipeline() - pp.add_component("wait", waiting_component()) - - run_data = [{"wait_for": 0.001}, {"wait_for": 0.002}] - - # Use ThreadPoolExecutor to run pipeline calls in parallel - with ThreadPoolExecutor(max_workers=len(run_data)) as executor: - # Submit pipeline runs to the executor - futures = [executor.submit(pp.run, data) for data in run_data] - - # Wait for all futures to complete - for future in futures: - future.result() - - # Verify component visits using tracer - component_spans = [sp for sp in spying_tracer.spans if sp.operation_name == "haystack.component.run"] - - for span in component_spans: - assert span.tags["haystack.component.visits"] == 1 - - def test_prepare_component_inputs(self): - joiner_1 = BranchJoiner(type_=str) - joiner_2 = BranchJoiner(type_=str) - pp = Pipeline() - component_name = "joiner_1" - pp.add_component(component_name, joiner_1) - pp.add_component("joiner_2", joiner_2) - pp.connect(component_name, "joiner_2") - inputs = {"joiner_1": {"value": [{"sender": None, "value": "test_value"}]}} - comp_dict = pp._get_component_with_graph_metadata_and_visits(component_name, 0) - - _ = pp._consume_component_inputs(component_name=component_name, component=comp_dict, inputs=inputs) - # We remove input in greedy variadic sockets, even if they are from the user - assert inputs == {"joiner_1": {}} - - def test__run_component_success(self): - """Test successful component execution""" - joiner_1 = BranchJoiner(type_=str) - joiner_2 = BranchJoiner(type_=str) - pp = Pipeline() - component_name = "joiner_1" - pp.add_component(component_name, joiner_1) - pp.add_component("joiner_2", joiner_2) - pp.connect(component_name, "joiner_2") - inputs = {"value": ["test_value"]} - - outputs = pp._run_component( - component_name=component_name, - component=pp._get_component_with_graph_metadata_and_visits(component_name, 0), - inputs=inputs, - component_visits={component_name: 0, "joiner_2": 0}, - ) - - assert outputs == {"value": "test_value"} - - def test__run_component_fail(self): - """Test error when component doesn't return a dictionary""" - - @component - class WrongOutput: - @component.output_types(output=str) - def run(self, value: str): - return "not_a_dict" - - wrong = WrongOutput() - pp = Pipeline() - pp.add_component("wrong", wrong) - inputs = {"value": "test_value"} - - with pytest.raises(PipelineRuntimeError) as exc_info: - pp._run_component( - component_name="wrong", - component=pp._get_component_with_graph_metadata_and_visits("wrong", 0), - inputs=inputs, - component_visits={"wrong": 0}, - ) - - assert "Expected a dict" in str(exc_info.value) - - def test_run(self): - joiner_1 = BranchJoiner(type_=str) - joiner_2 = BranchJoiner(type_=str) - pp = Pipeline() - pp.add_component("joiner_1", joiner_1) - pp.add_component("joiner_2", joiner_2) - pp.connect("joiner_1", "joiner_2") - - _ = pp.run({"value": "test_value"}) From ef2e2bc1fb83a79013ca013ad1e347cccb4fbb0f Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Mon, 14 Jul 2025 18:44:49 +0200 Subject: [PATCH 03/21] wip: fixing tests --- haystack/components/agents/agent.py | 28 +- haystack/components/agents/agent_original.py | 467 +++++++++++++++++++ 2 files changed, 485 insertions(+), 10 deletions(-) create mode 100644 haystack/components/agents/agent_original.py diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index ceeab4760b..713ab0ecbc 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -73,8 +73,9 @@ def __init__( exit_conditions: Optional[List[str]] = None, state_schema: Optional[Dict[str, Any]] = None, max_agent_steps: int = 100, - raise_on_tool_invocation_failure: bool = False, streaming_callback: Optional[StreamingCallbackT] = None, + raise_on_tool_invocation_failure: bool = False, + tool_invoker_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Initialize the agent component. @@ -88,10 +89,11 @@ def __init__( :param state_schema: The schema for the runtime state used by the tools. :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. If the agent exceeds this number of steps, it will stop and return the current state. - :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? - If set to False, the exception will be turned into a chat message and passed to the LLM. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? + If set to False, the exception will be turned into a chat message and passed to the LLM. + :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker. :raises TypeError: If the chat_generator does not support tools parameter in its run method. :raises ValueError: If the exit_conditions are not valid. """ @@ -141,9 +143,15 @@ def __init__( component.set_input_type(self, name=param, type=config["type"], default=None) component.set_output_types(self, **output_types) + self.tool_invoker_kwargs = tool_invoker_kwargs self._tool_invoker = None if self.tools: - self._tool_invoker = ToolInvoker(tools=self.tools, raise_on_failure=self.raise_on_tool_invocation_failure) + resolved_tool_invoker_kwargs = { + "tools": self.tools, + "raise_on_failure": self.raise_on_tool_invocation_failure, + **(tool_invoker_kwargs or {}), + } + self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs) else: logger.warning( "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text " @@ -151,7 +159,6 @@ def __init__( ) self._is_warmed_up = False - self._agent_name: Optional[str] = None def warm_up(self) -> None: """ @@ -182,8 +189,9 @@ def to_dict(self) -> Dict[str, Any]: # We serialize the original state schema, not the resolved one to reflect the original user input state_schema=_schema_to_dict(self._state_schema), max_agent_steps=self.max_agent_steps, - raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, streaming_callback=streaming_callback, + raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, + tool_invoker_kwargs=self.tool_invoker_kwargs, ) @classmethod @@ -429,8 +437,8 @@ def run( # noqa: PLR0915 "Agent will not perform any actions specific to user input. Consider adding user messages to the input." 
) - state = State(schema=self.state_schema, data=kwargs) - state.set("messages", messages) + # state = State(schema=self.state_schema, data=kwargs) + # state.set("messages", messages) streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False @@ -590,8 +598,8 @@ async def run_async( # noqa: PLR0915 "Agent will not perform any actions specific to user input. Consider adding user messages to the input." ) - state = State(schema=self.state_schema, data=kwargs) - state.set("messages", messages) + # state = State(schema=self.state_schema, data=kwargs) + # state.set("messages", messages) streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True diff --git a/haystack/components/agents/agent_original.py b/haystack/components/agents/agent_original.py new file mode 100644 index 0000000000..b9d618b3ee --- /dev/null +++ b/haystack/components/agents/agent_original.py @@ -0,0 +1,467 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import inspect +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict, logging, tracing +from haystack.components.generators.chat.types import ChatGenerator +from haystack.components.tools import ToolInvoker +from haystack.core.pipeline.async_pipeline import AsyncPipeline +from haystack.core.pipeline.pipeline import Pipeline +from haystack.core.pipeline.utils import _deepcopy_with_exceptions +from haystack.core.serialization import component_to_dict +from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback +from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset +from haystack.utils.callable_serialization import deserialize_callable, serialize_callable +from haystack.utils.deserialization import deserialize_chatgenerator_inplace + +from .state.state import State, _schema_from_dict, _schema_to_dict, _validate_schema +from .state.state_utils import merge_lists + +logger = logging.getLogger(__name__) + + +@component +class Agent: + """ + A Haystack component that implements a tool-using agent with provider-agnostic chat model support. + + The component processes messages and executes tools until an exit_condition condition is met. + The exit_condition can be triggered either by a direct text response or by invoking a specific designated tool. + + When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits. 
+ + ### Usage example + ```python + from haystack.components.agents import Agent + from haystack.components.generators.chat import OpenAIChatGenerator + from haystack.dataclasses import ChatMessage + from haystack.tools.tool import Tool + + tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")] + + agent = Agent( + chat_generator=OpenAIChatGenerator(), + tools=tools, + exit_condition="search", + ) + + # Run the agent + result = agent.run( + messages=[ChatMessage.from_user("Find information about Haystack")] + ) + + assert "messages" in result # Contains conversation history + ``` + """ + + def __init__( + self, + *, + chat_generator: ChatGenerator, + tools: Optional[Union[List[Tool], Toolset]] = None, + system_prompt: Optional[str] = None, + exit_conditions: Optional[List[str]] = None, + state_schema: Optional[Dict[str, Any]] = None, + max_agent_steps: int = 100, + streaming_callback: Optional[StreamingCallbackT] = None, + raise_on_tool_invocation_failure: bool = False, + tool_invoker_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initialize the agent component. + + :param chat_generator: An instance of the chat generator that your agent should use. It must support tools. + :param tools: List of Tool objects or a Toolset that the agent can use. + :param system_prompt: System prompt for the agent. + :param exit_conditions: List of conditions that will cause the agent to return. + Can include "text" if the agent should return when it generates a message without tool calls, + or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"]. + :param state_schema: The schema for the runtime state used by the tools. + :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. + If the agent exceeds this number of steps, it will stop and return the current state. + :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. + The same callback can be configured to emit tool results when a tool is called. + :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? + If set to False, the exception will be turned into a chat message and passed to the LLM. + :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker. + :raises TypeError: If the chat_generator does not support tools parameter in its run method. + :raises ValueError: If the exit_conditions are not valid. + """ + # Check if chat_generator supports tools parameter + chat_generator_run_method = inspect.signature(chat_generator.run) + if "tools" not in chat_generator_run_method.parameters: + raise TypeError( + f"{type(chat_generator).__name__} does not accept tools parameter in its run method. " + "The Agent component requires a chat generator that supports tools." + ) + + valid_exits = ["text"] + [tool.name for tool in tools or []] + if exit_conditions is None: + exit_conditions = ["text"] + if not all(condition in valid_exits for condition in exit_conditions): + raise ValueError( + f"Invalid exit conditions provided: {exit_conditions}. " + f"Valid exit conditions must be a subset of {valid_exits}. " + "Ensure that each exit condition corresponds to either 'text' or a valid tool name." 
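The `ValueError` above enforces that every exit condition is either `"text"` or the name of a provided tool. The rule in isolation:

```python
tool_names = ["calculator", "search"]
valid_exits = ["text"] + tool_names

assert all(c in valid_exits for c in ["text"])            # default: stop on a plain text reply
assert all(c in valid_exits for c in ["search"])          # stop once the "search" tool has run
assert not all(c in valid_exits for c in ["web_search"])  # unknown tool name: rejected
```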
+ ) + + # Validate state schema if provided + if state_schema is not None: + _validate_schema(state_schema) + self._state_schema = state_schema or {} + + # Initialize state schema + resolved_state_schema = _deepcopy_with_exceptions(self._state_schema) + if resolved_state_schema.get("messages") is None: + resolved_state_schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists} + self.state_schema = resolved_state_schema + + self.chat_generator = chat_generator + self.tools = tools or [] + self.system_prompt = system_prompt + self.exit_conditions = exit_conditions + self.max_agent_steps = max_agent_steps + self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure + self.streaming_callback = streaming_callback + + output_types = {"last_message": ChatMessage} + for param, config in self.state_schema.items(): + output_types[param] = config["type"] + # Skip setting input types for parameters that are already in the run method + if param in ["messages", "streaming_callback"]: + continue + component.set_input_type(self, name=param, type=config["type"], default=None) + component.set_output_types(self, **output_types) + + self.tool_invoker_kwargs = tool_invoker_kwargs + self._tool_invoker = None + if self.tools: + resolved_tool_invoker_kwargs = { + "tools": self.tools, + "raise_on_failure": self.raise_on_tool_invocation_failure, + **(tool_invoker_kwargs or {}), + } + self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs) + else: + logger.warning( + "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text " + "responses. To enable tool usage, pass tools directly to the Agent, not to the chat_generator." + ) + + self._is_warmed_up = False + + def warm_up(self) -> None: + """ + Warm up the Agent. + """ + if not self._is_warmed_up: + if hasattr(self.chat_generator, "warm_up"): + self.chat_generator.warm_up() + self._is_warmed_up = True + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the component to a dictionary. + + :return: Dictionary with serialized data + """ + if self.streaming_callback is not None: + streaming_callback = serialize_callable(self.streaming_callback) + else: + streaming_callback = None + + return default_to_dict( + self, + chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"), + tools=serialize_tools_or_toolset(self.tools), + system_prompt=self.system_prompt, + exit_conditions=self.exit_conditions, + # We serialize the original state schema, not the resolved one to reflect the original user input + state_schema=_schema_to_dict(self._state_schema), + max_agent_steps=self.max_agent_steps, + streaming_callback=streaming_callback, + raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, + tool_invoker_kwargs=self.tool_invoker_kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Agent": + """ + Deserialize the agent from a dictionary. 
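A custom `state_schema` of the shape handled above: each entry declares a `type` and, optionally, a merge `handler`; a `messages` entry is added automatically with `merge_lists` when absent. `keep_latest` below is a hypothetical handler:

```python
from typing import List

from haystack.dataclasses import Document


def keep_latest(current, new):
    # Replace the stored value instead of merging: a hypothetical policy.
    return new


state_schema = {
    "documents": {"type": List[Document], "handler": keep_latest},
    "user_id": {"type": str},
}
# Passed as Agent(..., state_schema=state_schema); the keys become extra
# run() inputs and appear in the result dictionary alongside "messages".
```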
+ + :param data: Dictionary to deserialize from + :return: Deserialized agent + """ + init_params = data.get("init_parameters", {}) + + deserialize_chatgenerator_inplace(init_params, key="chat_generator") + + if "state_schema" in init_params: + init_params["state_schema"] = _schema_from_dict(init_params["state_schema"]) + + if init_params.get("streaming_callback") is not None: + init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) + + deserialize_tools_or_toolset_inplace(init_params, key="tools") + + return default_from_dict(cls, data) + + def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]: + """Prepare inputs for the chat generator.""" + generator_inputs: Dict[str, Any] = {"tools": self.tools} + if streaming_callback is not None: + generator_inputs["streaming_callback"] = streaming_callback + return generator_inputs + + def _create_agent_span(self) -> Any: + """Create a span for the agent run.""" + return tracing.tracer.trace( + "haystack.agent.run", + tags={ + "haystack.agent.max_steps": self.max_agent_steps, + "haystack.agent.tools": self.tools, + "haystack.agent.exit_conditions": self.exit_conditions, + "haystack.agent.state_schema": _schema_to_dict(self.state_schema), + }, + ) + + def run( + self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + ) -> Dict[str, Any]: + """ + Process messages and execute tools until an exit condition is met. + + :param messages: List of Haystack ChatMessage objects to process. + If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object. + :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. + The same callback can be configured to emit tool results when a tool is called. + :param kwargs: Additional data to pass to the State schema used by the Agent. + The keys must match the schema defined in the Agent's `state_schema`. + :returns: + A dictionary with the following keys: + - "messages": List of all messages exchanged during the agent's run. + - "last_message": The last message exchanged during the agent's run. + - Any additional keys defined in the `state_schema`. + :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. + """ + if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): + raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.") + + if self.system_prompt is not None: + messages = [ChatMessage.from_system(self.system_prompt)] + messages + + if all(m.is_from(ChatRole.SYSTEM) for m in messages): + logger.warning( + "All messages provided to the Agent component are system messages. This is not recommended as the " + "Agent will not perform any actions specific to user input. Consider adding user messages to the input." 
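The `to_dict`/`from_dict` pair above supports a straightforward round trip. A sketch (assumes `OPENAI_API_KEY` is set in the environment for the example generator):

```python
from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator

agent = Agent(chat_generator=OpenAIChatGenerator(), max_agent_steps=10)
data = agent.to_dict()
restored = Agent.from_dict(data)
assert restored.max_agent_steps == 10
```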
+ ) + + state = State(schema=self.state_schema, data=kwargs) + state.set("messages", messages) + component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) + + streaming_callback = select_streaming_callback( + init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False + ) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) + with self._create_agent_span() as span: + span.set_content_tag( + "haystack.agent.input", + _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}), + ) + counter = 0 + while counter < self.max_agent_steps: + # 1. Call the ChatGenerator + result = Pipeline._run_component( + component_name="chat_generator", + component={"instance": self.chat_generator}, + inputs={"messages": messages, **generator_inputs}, + component_visits=component_visits, + parent_span=span, + ) + llm_messages = result["replies"] + state.set("messages", llm_messages) + + # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools + if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: + counter += 1 + break + + # 3. Call the ToolInvoker + # We only send the messages from the LLM to the tool invoker + tool_invoker_result = Pipeline._run_component( + component_name="tool_invoker", + component={"instance": self._tool_invoker}, + inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback}, + component_visits=component_visits, + parent_span=span, + ) + tool_messages = tool_invoker_result["tool_messages"] + state = tool_invoker_result["state"] + state.set("messages", tool_messages) + + # 4. Check if any LLM message's tool call name matches an exit condition + if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): + counter += 1 + break + + # 5. Fetch the combined messages and send them back to the LLM + messages = state.get("messages") + counter += 1 + + if counter >= self.max_agent_steps: + logger.warning( + "Agent reached maximum agent steps of {max_agent_steps}, stopping.", + max_agent_steps=self.max_agent_steps, + ) + span.set_content_tag("haystack.agent.output", state.data) + span.set_tag("haystack.agent.steps_taken", counter) + + result = {**state.data} + all_messages = state.get("messages") + if all_messages: + result.update({"last_message": all_messages[-1]}) + return result + + async def run_async( + self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + ) -> Dict[str, Any]: + """ + Asynchronously process messages and execute tools until the exit condition is met. + + This is the asynchronous version of the `run` method. It follows the same logic but uses + asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator + if available. + + :param messages: List of chat messages to process + :param streaming_callback: An asynchronous callback that will be invoked when a response + is streamed from the LLM. The same callback can be configured to emit tool results + when a tool is called. + :param kwargs: Additional data to pass to the State schema used by the Agent. + The keys must match the schema defined in the Agent's `state_schema`. + :returns: + A dictionary with the following keys: + - "messages": List of all messages exchanged during the agent's run. + - "last_message": The last message exchanged during the agent's run. 
+ - Any additional keys defined in the `state_schema`. + :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`. + """ + if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): + raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.") + + if self.system_prompt is not None: + messages = [ChatMessage.from_system(self.system_prompt)] + messages + + if all(m.is_from(ChatRole.SYSTEM) for m in messages): + logger.warning( + "All messages provided to the Agent component are system messages. This is not recommended as the " + "Agent will not perform any actions specific to user input. Consider adding user messages to the input." + ) + + state = State(schema=self.state_schema, data=kwargs) + state.set("messages", messages) + component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) + + streaming_callback = select_streaming_callback( + init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True + ) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) + with self._create_agent_span() as span: + span.set_content_tag( + "haystack.agent.input", + _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}), + ) + counter = 0 + while counter < self.max_agent_steps: + # 1. Call the ChatGenerator + result = await AsyncPipeline._run_component_async( + component_name="chat_generator", + component={"instance": self.chat_generator}, + component_inputs={"messages": messages, **generator_inputs}, + component_visits=component_visits, + max_runs_per_component=self.max_agent_steps, + parent_span=span, + ) + llm_messages = result["replies"] + state.set("messages", llm_messages) + + # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools + if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: + counter += 1 + break + + # 3. Call the ToolInvoker + # We only send the messages from the LLM to the tool invoker + # Check if the ToolInvoker supports async execution. Currently, it doesn't. + tool_invoker_result = await AsyncPipeline._run_component_async( + component_name="tool_invoker", + component={"instance": self._tool_invoker}, + component_inputs={ + "messages": llm_messages, + "state": state, + "streaming_callback": streaming_callback, + }, + component_visits=component_visits, + max_runs_per_component=self.max_agent_steps, + parent_span=span, + ) + tool_messages = tool_invoker_result["tool_messages"] + state = tool_invoker_result["state"] + state.set("messages", tool_messages) + + # 4. Check if any LLM message's tool call name matches an exit condition + if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): + counter += 1 + break + + # 5. 
Fetch the combined messages and send them back to the LLM + messages = state.get("messages") + counter += 1 + + if counter >= self.max_agent_steps: + logger.warning( + "Agent reached maximum agent steps of {max_agent_steps}, stopping.", + max_agent_steps=self.max_agent_steps, + ) + span.set_content_tag("haystack.agent.output", state.data) + span.set_tag("haystack.agent.steps_taken", counter) + + result = {**state.data} + all_messages = state.get("messages") + if all_messages: + result.update({"last_message": all_messages[-1]}) + return result + + def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool: + """ + Check if any of the LLM messages' tool calls match an exit condition and if there are no errors. + + :param llm_messages: List of messages from the LLM + :param tool_messages: List of messages from the tool invoker + :return: True if an exit condition is met and there are no errors, False otherwise + """ + matched_exit_conditions = set() + has_errors = False + + for msg in llm_messages: + if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions: + matched_exit_conditions.add(msg.tool_call.tool_name) + + # Check if any error is specifically from the tool matching the exit condition + tool_errors = [ + tool_msg.tool_call_result.error + for tool_msg in tool_messages + if tool_msg.tool_call_result is not None + and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name + ] + if any(tool_errors): + has_errors = True + # No need to check further if we found an error + break + + # Only return True if at least one exit condition was matched AND none had errors + return bool(matched_exit_conditions) and not has_errors From 9fad9bbb571100946fce3c3e414d27bed10012d4 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Mon, 14 Jul 2025 18:47:43 +0200 Subject: [PATCH 04/21] wip: fixing tests --- haystack/components/agents/__init__.py | 2 +- haystack/components/agents/agent_original.py | 467 ----------------- .../agents/test_state_class_experimental.py | 477 ------------------ 3 files changed, 1 insertion(+), 945 deletions(-) delete mode 100644 haystack/components/agents/agent_original.py delete mode 100644 test/components/agents/test_state_class_experimental.py diff --git a/haystack/components/agents/__init__.py b/haystack/components/agents/__init__.py index f94e305a6e..d331918f68 100644 --- a/haystack/components/agents/__init__.py +++ b/haystack/components/agents/__init__.py @@ -10,7 +10,7 @@ _import_structure = {"agent": ["Agent"], "state": ["State"]} if TYPE_CHECKING: - from .origina_agent import Agent as Agent + from .agent import Agent as Agent from .state import State as State else: diff --git a/haystack/components/agents/agent_original.py b/haystack/components/agents/agent_original.py deleted file mode 100644 index b9d618b3ee..0000000000 --- a/haystack/components/agents/agent_original.py +++ /dev/null @@ -1,467 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import inspect -from typing import Any, Dict, List, Optional, Union - -from haystack import component, default_from_dict, default_to_dict, logging, tracing -from haystack.components.generators.chat.types import ChatGenerator -from haystack.components.tools import ToolInvoker -from haystack.core.pipeline.async_pipeline import AsyncPipeline -from haystack.core.pipeline.pipeline import Pipeline -from haystack.core.pipeline.utils import _deepcopy_with_exceptions -from haystack.core.serialization import component_to_dict -from haystack.dataclasses import ChatMessage, ChatRole -from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback -from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset -from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from haystack.utils.deserialization import deserialize_chatgenerator_inplace - -from .state.state import State, _schema_from_dict, _schema_to_dict, _validate_schema -from .state.state_utils import merge_lists - -logger = logging.getLogger(__name__) - - -@component -class Agent: - """ - A Haystack component that implements a tool-using agent with provider-agnostic chat model support. - - The component processes messages and executes tools until an exit_condition condition is met. - The exit_condition can be triggered either by a direct text response or by invoking a specific designated tool. - - When you call an Agent without tools, it acts as a ChatGenerator, produces one response, then exits. 
- - ### Usage example - ```python - from haystack.components.agents import Agent - from haystack.components.generators.chat import OpenAIChatGenerator - from haystack.dataclasses import ChatMessage - from haystack.tools.tool import Tool - - tools = [Tool(name="calculator", description="..."), Tool(name="search", description="...")] - - agent = Agent( - chat_generator=OpenAIChatGenerator(), - tools=tools, - exit_condition="search", - ) - - # Run the agent - result = agent.run( - messages=[ChatMessage.from_user("Find information about Haystack")] - ) - - assert "messages" in result # Contains conversation history - ``` - """ - - def __init__( - self, - *, - chat_generator: ChatGenerator, - tools: Optional[Union[List[Tool], Toolset]] = None, - system_prompt: Optional[str] = None, - exit_conditions: Optional[List[str]] = None, - state_schema: Optional[Dict[str, Any]] = None, - max_agent_steps: int = 100, - streaming_callback: Optional[StreamingCallbackT] = None, - raise_on_tool_invocation_failure: bool = False, - tool_invoker_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Initialize the agent component. - - :param chat_generator: An instance of the chat generator that your agent should use. It must support tools. - :param tools: List of Tool objects or a Toolset that the agent can use. - :param system_prompt: System prompt for the agent. - :param exit_conditions: List of conditions that will cause the agent to return. - Can include "text" if the agent should return when it generates a message without tool calls, - or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"]. - :param state_schema: The schema for the runtime state used by the tools. - :param max_agent_steps: Maximum number of steps the agent will run before stopping. Defaults to 100. - If the agent exceeds this number of steps, it will stop and return the current state. - :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. - The same callback can be configured to emit tool results when a tool is called. - :param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails? - If set to False, the exception will be turned into a chat message and passed to the LLM. - :param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker. - :raises TypeError: If the chat_generator does not support tools parameter in its run method. - :raises ValueError: If the exit_conditions are not valid. - """ - # Check if chat_generator supports tools parameter - chat_generator_run_method = inspect.signature(chat_generator.run) - if "tools" not in chat_generator_run_method.parameters: - raise TypeError( - f"{type(chat_generator).__name__} does not accept tools parameter in its run method. " - "The Agent component requires a chat generator that supports tools." - ) - - valid_exits = ["text"] + [tool.name for tool in tools or []] - if exit_conditions is None: - exit_conditions = ["text"] - if not all(condition in valid_exits for condition in exit_conditions): - raise ValueError( - f"Invalid exit conditions provided: {exit_conditions}. " - f"Valid exit conditions must be a subset of {valid_exits}. " - "Ensure that each exit condition corresponds to either 'text' or a valid tool name." 
- ) - - # Validate state schema if provided - if state_schema is not None: - _validate_schema(state_schema) - self._state_schema = state_schema or {} - - # Initialize state schema - resolved_state_schema = _deepcopy_with_exceptions(self._state_schema) - if resolved_state_schema.get("messages") is None: - resolved_state_schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists} - self.state_schema = resolved_state_schema - - self.chat_generator = chat_generator - self.tools = tools or [] - self.system_prompt = system_prompt - self.exit_conditions = exit_conditions - self.max_agent_steps = max_agent_steps - self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure - self.streaming_callback = streaming_callback - - output_types = {"last_message": ChatMessage} - for param, config in self.state_schema.items(): - output_types[param] = config["type"] - # Skip setting input types for parameters that are already in the run method - if param in ["messages", "streaming_callback"]: - continue - component.set_input_type(self, name=param, type=config["type"], default=None) - component.set_output_types(self, **output_types) - - self.tool_invoker_kwargs = tool_invoker_kwargs - self._tool_invoker = None - if self.tools: - resolved_tool_invoker_kwargs = { - "tools": self.tools, - "raise_on_failure": self.raise_on_tool_invocation_failure, - **(tool_invoker_kwargs or {}), - } - self._tool_invoker = ToolInvoker(**resolved_tool_invoker_kwargs) - else: - logger.warning( - "No tools provided to the Agent. The Agent will behave like a ChatGenerator and only return text " - "responses. To enable tool usage, pass tools directly to the Agent, not to the chat_generator." - ) - - self._is_warmed_up = False - - def warm_up(self) -> None: - """ - Warm up the Agent. - """ - if not self._is_warmed_up: - if hasattr(self.chat_generator, "warm_up"): - self.chat_generator.warm_up() - self._is_warmed_up = True - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize the component to a dictionary. - - :return: Dictionary with serialized data - """ - if self.streaming_callback is not None: - streaming_callback = serialize_callable(self.streaming_callback) - else: - streaming_callback = None - - return default_to_dict( - self, - chat_generator=component_to_dict(obj=self.chat_generator, name="chat_generator"), - tools=serialize_tools_or_toolset(self.tools), - system_prompt=self.system_prompt, - exit_conditions=self.exit_conditions, - # We serialize the original state schema, not the resolved one to reflect the original user input - state_schema=_schema_to_dict(self._state_schema), - max_agent_steps=self.max_agent_steps, - streaming_callback=streaming_callback, - raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, - tool_invoker_kwargs=self.tool_invoker_kwargs, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Agent": - """ - Deserialize the agent from a dictionary. 
- - :param data: Dictionary to deserialize from - :return: Deserialized agent - """ - init_params = data.get("init_parameters", {}) - - deserialize_chatgenerator_inplace(init_params, key="chat_generator") - - if "state_schema" in init_params: - init_params["state_schema"] = _schema_from_dict(init_params["state_schema"]) - - if init_params.get("streaming_callback") is not None: - init_params["streaming_callback"] = deserialize_callable(init_params["streaming_callback"]) - - deserialize_tools_or_toolset_inplace(init_params, key="tools") - - return default_from_dict(cls, data) - - def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]: - """Prepare inputs for the chat generator.""" - generator_inputs: Dict[str, Any] = {"tools": self.tools} - if streaming_callback is not None: - generator_inputs["streaming_callback"] = streaming_callback - return generator_inputs - - def _create_agent_span(self) -> Any: - """Create a span for the agent run.""" - return tracing.tracer.trace( - "haystack.agent.run", - tags={ - "haystack.agent.max_steps": self.max_agent_steps, - "haystack.agent.tools": self.tools, - "haystack.agent.exit_conditions": self.exit_conditions, - "haystack.agent.state_schema": _schema_to_dict(self.state_schema), - }, - ) - - def run( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any - ) -> Dict[str, Any]: - """ - Process messages and execute tools until an exit condition is met. - - :param messages: List of Haystack ChatMessage objects to process. - If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object. - :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. - The same callback can be configured to emit tool results when a tool is called. - :param kwargs: Additional data to pass to the State schema used by the Agent. - The keys must match the schema defined in the Agent's `state_schema`. - :returns: - A dictionary with the following keys: - - "messages": List of all messages exchanged during the agent's run. - - "last_message": The last message exchanged during the agent's run. - - Any additional keys defined in the `state_schema`. - :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. - """ - if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): - raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.") - - if self.system_prompt is not None: - messages = [ChatMessage.from_system(self.system_prompt)] + messages - - if all(m.is_from(ChatRole.SYSTEM) for m in messages): - logger.warning( - "All messages provided to the Agent component are system messages. This is not recommended as the " - "Agent will not perform any actions specific to user input. Consider adding user messages to the input." 
- ) - - state = State(schema=self.state_schema, data=kwargs) - state.set("messages", messages) - component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) - - streaming_callback = select_streaming_callback( - init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False - ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) - with self._create_agent_span() as span: - span.set_content_tag( - "haystack.agent.input", - _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}), - ) - counter = 0 - while counter < self.max_agent_steps: - # 1. Call the ChatGenerator - result = Pipeline._run_component( - component_name="chat_generator", - component={"instance": self.chat_generator}, - inputs={"messages": messages, **generator_inputs}, - component_visits=component_visits, - parent_span=span, - ) - llm_messages = result["replies"] - state.set("messages", llm_messages) - - # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools - if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: - counter += 1 - break - - # 3. Call the ToolInvoker - # We only send the messages from the LLM to the tool invoker - tool_invoker_result = Pipeline._run_component( - component_name="tool_invoker", - component={"instance": self._tool_invoker}, - inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback}, - component_visits=component_visits, - parent_span=span, - ) - tool_messages = tool_invoker_result["tool_messages"] - state = tool_invoker_result["state"] - state.set("messages", tool_messages) - - # 4. Check if any LLM message's tool call name matches an exit condition - if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): - counter += 1 - break - - # 5. Fetch the combined messages and send them back to the LLM - messages = state.get("messages") - counter += 1 - - if counter >= self.max_agent_steps: - logger.warning( - "Agent reached maximum agent steps of {max_agent_steps}, stopping.", - max_agent_steps=self.max_agent_steps, - ) - span.set_content_tag("haystack.agent.output", state.data) - span.set_tag("haystack.agent.steps_taken", counter) - - result = {**state.data} - all_messages = state.get("messages") - if all_messages: - result.update({"last_message": all_messages[-1]}) - return result - - async def run_async( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any - ) -> Dict[str, Any]: - """ - Asynchronously process messages and execute tools until the exit condition is met. - - This is the asynchronous version of the `run` method. It follows the same logic but uses - asynchronous operations where possible, such as calling the `run_async` method of the ChatGenerator - if available. - - :param messages: List of chat messages to process - :param streaming_callback: An asynchronous callback that will be invoked when a response - is streamed from the LLM. The same callback can be configured to emit tool results - when a tool is called. - :param kwargs: Additional data to pass to the State schema used by the Agent. - The keys must match the schema defined in the Agent's `state_schema`. - :returns: - A dictionary with the following keys: - - "messages": List of all messages exchanged during the agent's run. - - "last_message": The last message exchanged during the agent's run. 
- - Any additional keys defined in the `state_schema`. - :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`. - """ - if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): - raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.") - - if self.system_prompt is not None: - messages = [ChatMessage.from_system(self.system_prompt)] + messages - - if all(m.is_from(ChatRole.SYSTEM) for m in messages): - logger.warning( - "All messages provided to the Agent component are system messages. This is not recommended as the " - "Agent will not perform any actions specific to user input. Consider adding user messages to the input." - ) - - state = State(schema=self.state_schema, data=kwargs) - state.set("messages", messages) - component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) - - streaming_callback = select_streaming_callback( - init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True - ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) - with self._create_agent_span() as span: - span.set_content_tag( - "haystack.agent.input", - _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}), - ) - counter = 0 - while counter < self.max_agent_steps: - # 1. Call the ChatGenerator - result = await AsyncPipeline._run_component_async( - component_name="chat_generator", - component={"instance": self.chat_generator}, - component_inputs={"messages": messages, **generator_inputs}, - component_visits=component_visits, - max_runs_per_component=self.max_agent_steps, - parent_span=span, - ) - llm_messages = result["replies"] - state.set("messages", llm_messages) - - # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools - if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: - counter += 1 - break - - # 3. Call the ToolInvoker - # We only send the messages from the LLM to the tool invoker - # Check if the ToolInvoker supports async execution. Currently, it doesn't. - tool_invoker_result = await AsyncPipeline._run_component_async( - component_name="tool_invoker", - component={"instance": self._tool_invoker}, - component_inputs={ - "messages": llm_messages, - "state": state, - "streaming_callback": streaming_callback, - }, - component_visits=component_visits, - max_runs_per_component=self.max_agent_steps, - parent_span=span, - ) - tool_messages = tool_invoker_result["tool_messages"] - state = tool_invoker_result["state"] - state.set("messages", tool_messages) - - # 4. Check if any LLM message's tool call name matches an exit condition - if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages): - counter += 1 - break - - # 5. 
Fetch the combined messages and send them back to the LLM - messages = state.get("messages") - counter += 1 - - if counter >= self.max_agent_steps: - logger.warning( - "Agent reached maximum agent steps of {max_agent_steps}, stopping.", - max_agent_steps=self.max_agent_steps, - ) - span.set_content_tag("haystack.agent.output", state.data) - span.set_tag("haystack.agent.steps_taken", counter) - - result = {**state.data} - all_messages = state.get("messages") - if all_messages: - result.update({"last_message": all_messages[-1]}) - return result - - def _check_exit_conditions(self, llm_messages: List[ChatMessage], tool_messages: List[ChatMessage]) -> bool: - """ - Check if any of the LLM messages' tool calls match an exit condition and if there are no errors. - - :param llm_messages: List of messages from the LLM - :param tool_messages: List of messages from the tool invoker - :return: True if an exit condition is met and there are no errors, False otherwise - """ - matched_exit_conditions = set() - has_errors = False - - for msg in llm_messages: - if msg.tool_call and msg.tool_call.tool_name in self.exit_conditions: - matched_exit_conditions.add(msg.tool_call.tool_name) - - # Check if any error is specifically from the tool matching the exit condition - tool_errors = [ - tool_msg.tool_call_result.error - for tool_msg in tool_messages - if tool_msg.tool_call_result is not None - and tool_msg.tool_call_result.origin.tool_name == msg.tool_call.tool_name - ] - if any(tool_errors): - has_errors = True - # No need to check further if we found an error - break - - # Only return True if at least one exit condition was matched AND none had errors - return bool(matched_exit_conditions) and not has_errors diff --git a/test/components/agents/test_state_class_experimental.py b/test/components/agents/test_state_class_experimental.py deleted file mode 100644 index 9b7d3d7728..0000000000 --- a/test/components/agents/test_state_class_experimental.py +++ /dev/null @@ -1,477 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import inspect -from dataclasses import dataclass -from typing import Dict, Generic, List, Optional, TypeVar, Union - -import pytest - -from haystack.components.agents.state.state import ( - State, - _is_list_type, - _is_valid_type, - _schema_from_dict, - _schema_to_dict, - _validate_schema, - merge_lists, -) -from haystack.dataclasses import ChatMessage - - -@pytest.fixture -def basic_schema(): - return {"numbers": {"type": list}, "metadata": {"type": dict}, "name": {"type": str}} - - -def numbers_handler(current, new): - if current is None: - return sorted(set(new)) - return sorted(set(current + new)) - - -@pytest.fixture -def complex_schema(): - return {"numbers": {"type": list, "handler": numbers_handler}, "metadata": {"type": dict}, "name": {"type": str}} - - -def test_is_list_type(): - assert _is_list_type(list) is True - assert _is_list_type(List[int]) is True - assert _is_list_type(List[str]) is True - assert _is_list_type(dict) is False - assert _is_list_type(int) is False - assert _is_list_type(Union[List[int], None]) is False - - -class TestMergeLists: - def test_merge_two_lists(self): - current = [1, 2, 3] - new = [4, 5, 6] - result = merge_lists(current, new) - assert result == [1, 2, 3, 4, 5, 6] - # Ensure original lists weren't modified - assert current == [1, 2, 3] - assert new == [4, 5, 6] - - def test_append_to_list(self): - current = [1, 2, 3] - new = 4 - result = merge_lists(current, new) - assert result == [1, 2, 
3, 4] - assert current == [1, 2, 3] # Ensure original wasn't modified - - def test_create_new_list(self): - current = 1 - new = 2 - result = merge_lists(current, new) - assert result == [1, 2] - - def test_replace_with_list(self): - current = 1 - new = [2, 3] - result = merge_lists(current, new) - assert result == [1, 2, 3] - - -class TestIsValidType: - def test_builtin_types(self): - assert _is_valid_type(str) is True - assert _is_valid_type(int) is True - assert _is_valid_type(dict) is True - assert _is_valid_type(list) is True - assert _is_valid_type(tuple) is True - assert _is_valid_type(set) is True - assert _is_valid_type(bool) is True - assert _is_valid_type(float) is True - - def test_generic_types(self): - assert _is_valid_type(List[str]) is True - assert _is_valid_type(Dict[str, int]) is True - assert _is_valid_type(List[Dict[str, int]]) is True - assert _is_valid_type(Dict[str, List[int]]) is True - - def test_custom_classes(self): - @dataclass - class CustomClass: - value: int - - T = TypeVar("T") - - class GenericCustomClass(Generic[T]): - pass - - # Test regular and generic custom classes - assert _is_valid_type(CustomClass) is True - assert _is_valid_type(GenericCustomClass) is True - assert _is_valid_type(GenericCustomClass[int]) is True - - # Test generic types with custom classes - assert _is_valid_type(List[CustomClass]) is True - assert _is_valid_type(Dict[str, CustomClass]) is True - assert _is_valid_type(Dict[str, GenericCustomClass[int]]) is True - - def test_invalid_types(self): - # Test regular values - assert _is_valid_type(42) is False - assert _is_valid_type("string") is False - assert _is_valid_type([1, 2, 3]) is False - assert _is_valid_type({"a": 1}) is False - assert _is_valid_type(True) is False - - # Test class instances - @dataclass - class SampleClass: - value: int - - instance = SampleClass(42) - assert _is_valid_type(instance) is False - - # Test callable objects - assert _is_valid_type(len) is False - assert _is_valid_type(lambda x: x) is False - assert _is_valid_type(print) is False - - def test_union_and_optional_types(self): - # Test basic Union types - assert _is_valid_type(Union[str, int]) is True - assert _is_valid_type(Union[str, None]) is True - assert _is_valid_type(Union[List[int], Dict[str, str]]) is True - - # Test Optional types (which are Union[T, None]) - assert _is_valid_type(Optional[str]) is True - assert _is_valid_type(Optional[List[int]]) is True - assert _is_valid_type(Optional[Dict[str, list]]) is True - - # Test that Union itself is not a valid type (only instantiated Unions are) - assert _is_valid_type(Union) is False - - def test_nested_generic_types(self): - assert _is_valid_type(List[List[Dict[str, List[int]]]]) is True - assert _is_valid_type(Dict[str, List[Dict[str, set]]]) is True - assert _is_valid_type(Dict[str, Optional[List[int]]]) is True - assert _is_valid_type(List[Union[str, Dict[str, List[int]]]]) is True - - def test_edge_cases(self): - # Test None and NoneType - assert _is_valid_type(None) is False - assert _is_valid_type(type(None)) is True - - # Test functions and methods - def sample_func(): - pass - - assert _is_valid_type(sample_func) is False - assert _is_valid_type(type(sample_func)) is True - - # Test modules - assert _is_valid_type(inspect) is False - - # Test type itself - assert _is_valid_type(type) is True - - @pytest.mark.parametrize( - "test_input,expected", - [ - (str, True), - (int, True), - (List[int], True), - (Dict[str, int], True), - (Union[str, int], True), - (Optional[str], True), - (42, 
False), - ("string", False), - ([1, 2, 3], False), - (lambda x: x, False), - ], - ) - def test_parametrized_cases(self, test_input, expected): - assert _is_valid_type(test_input) is expected - - -class TestState: - def test_validate_schema_valid(self, basic_schema): - # Should not raise any exceptions - _validate_schema(basic_schema) - - def test_validate_schema_invalid_type(self): - invalid_schema = {"test": {"type": "not_a_type"}} - with pytest.raises(ValueError, match="must be a Python type"): - _validate_schema(invalid_schema) - - def test_validate_schema_missing_type(self): - invalid_schema = {"test": {"handler": lambda x, y: x + y}} - with pytest.raises(ValueError, match="missing a 'type' entry"): - _validate_schema(invalid_schema) - - def test_validate_schema_invalid_handler(self): - invalid_schema = {"test": {"type": list, "handler": "not_callable"}} - with pytest.raises(ValueError, match="must be callable or None"): - _validate_schema(invalid_schema) - - def test_state_initialization(self, basic_schema): - # Test empty initialization - state = State(basic_schema) - assert state.data == {} - - # Test initialization with data - initial_data = {"numbers": [1, 2, 3], "name": "test"} - state = State(basic_schema, initial_data) - assert state.data["numbers"] == [1, 2, 3] - assert state.data["name"] == "test" - - def test_state_get(self, basic_schema): - state = State(basic_schema, {"name": "test"}) - assert state.get("name") == "test" - assert state.get("non_existent") is None - assert state.get("non_existent", "default") == "default" - - def test_state_set_basic(self, basic_schema): - state = State(basic_schema) - - # Test setting new values - state.set("numbers", [1, 2]) - assert state.get("numbers") == [1, 2] - - # Test updating existing values - state.set("numbers", [3, 4]) - assert state.get("numbers") == [1, 2, 3, 4] - - def test_state_set_with_handler(self, complex_schema): - state = State(complex_schema) - - # Test custom handler for numbers - state.set("numbers", [3, 2, 1]) - assert state.get("numbers") == [1, 2, 3] - - state.set("numbers", [6, 5, 4]) - assert state.get("numbers") == [1, 2, 3, 4, 5, 6] - - def test_state_set_with_handler_override(self, basic_schema): - state = State(basic_schema) - - # Custom handler that concatenates strings - custom_handler = lambda current, new: f"{current}-{new}" if current else new - - state.set("name", "first") - state.set("name", "second", handler_override=custom_handler) - assert state.get("name") == "first-second" - - def test_state_has(self, basic_schema): - state = State(basic_schema, {"name": "test"}) - assert state.has("name") is True - assert state.has("non_existent") is False - - def test_state_empty_schema(self): - state = State({}) - assert state.data == {} - - # Instead of comparing the entire schema directly, check structure separately - assert "messages" in state.schema - assert state.schema["messages"]["type"] == List[ChatMessage] - assert callable(state.schema["messages"]["handler"]) - - with pytest.raises(ValueError, match="Key 'any_key' not found in schema"): - state.set("any_key", "value") - - def test_state_none_values(self, basic_schema): - state = State(basic_schema) - state.set("name", None) - assert state.get("name") is None - state.set("name", "value") - assert state.get("name") == "value" - - def test_state_merge_lists(self, basic_schema): - state = State(basic_schema) - state.set("numbers", "not_a_list") - assert state.get("numbers") == ["not_a_list"] - state.set("numbers", [1, 2]) - assert state.get("numbers") 
== ["not_a_list", 1, 2] - - def test_state_nested_structures(self): - schema = { - "complex": { - "type": Dict[str, List[int]], - "handler": lambda current, new: { - k: current.get(k, []) + new.get(k, []) for k in set(current.keys()) | set(new.keys()) - } - if current - else new, - } - } - - state = State(schema) - state.set("complex", {"a": [1, 2], "b": [3, 4]}) - state.set("complex", {"b": [5, 6], "c": [7, 8]}) - - expected = {"a": [1, 2], "b": [3, 4, 5, 6], "c": [7, 8]} - assert state.get("complex") == expected - - def test_schema_to_dict(self, basic_schema): - expected_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}} - result = _schema_to_dict(basic_schema) - assert result == expected_dict - - def test_schema_to_dict_with_handlers(self, complex_schema): - expected_dict = { - "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"}, - "metadata": {"type": "dict"}, - "name": {"type": "str"}, - } - result = _schema_to_dict(complex_schema) - assert result == expected_dict - - def test_schema_from_dict(self, basic_schema): - schema_dict = {"numbers": {"type": "list"}, "metadata": {"type": "dict"}, "name": {"type": "str"}} - result = _schema_from_dict(schema_dict) - assert result == basic_schema - - def test_schema_from_dict_with_handlers(self, complex_schema): - schema_dict = { - "numbers": {"type": "list", "handler": "test_state_class.numbers_handler"}, - "metadata": {"type": "dict"}, - "name": {"type": "str"}, - } - result = _schema_from_dict(schema_dict) - assert result == complex_schema - - def test_state_mutability(self): - state = State({"my_list": {"type": list}}, {"my_list": [1, 2]}) - - my_list = state.get("my_list") - my_list.append(3) - - assert state.get("my_list") == [1, 2] - - def test_state_to_dict(self): - # we test dict, a python type and a haystack dataclass - state_schema = { - "numbers": {"type": int}, - "messages": {"type": List[ChatMessage]}, - "dict_of_lists": {"type": dict}, - } - - data = { - "numbers": 1, - "messages": [ChatMessage.from_user(text="Hello, world!")], - "dict_of_lists": {"numbers": [1, 2, 3]}, - } - state = State(state_schema, data) - state_dict = state.to_dict() - assert state_dict["schema"] == { - "numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, - "messages": { - "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", - "handler": "haystack.components.agents.state.state_utils.merge_lists", - }, - "dict_of_lists": {"type": "dict", "handler": "haystack.components.agents.state.state_utils.replace_values"}, - } - assert state_dict["data"] == { - "serialization_schema": { - "type": "object", - "properties": { - "numbers": {"type": "integer"}, - "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}}, - "dict_of_lists": { - "type": "object", - "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}}, - }, - }, - }, - "serialized_data": { - "numbers": 1, - "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}], - "dict_of_lists": {"numbers": [1, 2, 3]}, - }, - } - - def test_state_from_dict(self): - state_dict = { - "schema": { - "numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, - "messages": { - "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", - "handler": "haystack.components.agents.state.state_utils.merge_lists", - }, - "dict_of_lists": { - "type": "dict", 
- "handler": "haystack.components.agents.state.state_utils.replace_values", - }, - }, - "data": { - "serialization_schema": { - "type": "object", - "properties": { - "numbers": {"type": "integer"}, - "messages": { - "type": "array", - "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}, - }, - "dict_of_lists": { - "type": "object", - "properties": {"numbers": {"type": "array", "items": {"type": "integer"}}}, - }, - }, - }, - "serialized_data": { - "numbers": 1, - "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}], - "dict_of_lists": {"numbers": [1, 2, 3]}, - }, - }, - } - state = State.from_dict(state_dict) - # Check types are correctly converted - assert state.schema["numbers"]["type"] == int - assert state.schema["dict_of_lists"]["type"] == dict - # Check handlers are functions, not comparing exact functions as they might be different references - assert callable(state.schema["numbers"]["handler"]) - assert callable(state.schema["messages"]["handler"]) - assert callable(state.schema["dict_of_lists"]["handler"]) - # Check data is correct - assert state.data["numbers"] == 1 - assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")] - assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]} - - def test_state_from_dict_legacy(self): - # this is the old format of the state dictionary - # it is kept for backward compatibility - # it will be removed in Haystack 2.16.0 - state_dict = { - "schema": { - "numbers": {"type": "int", "handler": "haystack.components.agents.state.state_utils.replace_values"}, - "messages": { - "type": "typing.List[haystack.dataclasses.chat_message.ChatMessage]", - "handler": "haystack.components.agents.state.state_utils.merge_lists", - }, - "dict_of_lists": { - "type": "dict", - "handler": "haystack.components.agents.state.state_utils.replace_values", - }, - }, - "data": { - "serialization_schema": { - "numbers": {"type": "integer"}, - "messages": {"type": "array", "items": {"type": "haystack.dataclasses.chat_message.ChatMessage"}}, - "dict_of_lists": {"type": "object"}, - }, - "serialized_data": { - "numbers": 1, - "messages": [{"role": "user", "meta": {}, "name": None, "content": [{"text": "Hello, world!"}]}], - "dict_of_lists": {"numbers": [1, 2, 3]}, - }, - }, - } - state = State.from_dict(state_dict) - # Check types are correctly converted - assert state.schema["numbers"]["type"] == int - assert state.schema["dict_of_lists"]["type"] == dict - # Check handlers are functions, not comparing exact functions as they might be different references - assert callable(state.schema["numbers"]["handler"]) - assert callable(state.schema["messages"]["handler"]) - assert callable(state.schema["dict_of_lists"]["handler"]) - # Check data is correct - assert state.data["numbers"] == 1 - assert state.data["messages"] == [ChatMessage.from_user(text="Hello, world!")] - assert state.data["dict_of_lists"] == {"numbers": [1, 2, 3]} From 7c608e0836de48e4402cbfb917fb7ebdf89e80f7 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Tue, 15 Jul 2025 10:54:26 +0200 Subject: [PATCH 05/21] fixing circular imports --- haystack/components/tools/tool_invoker.py | 3 +- haystack/core/pipeline/pipeline.py | 218 ++-------------------- 2 files changed, 16 insertions(+), 205 deletions(-) diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 679460496f..af425c50f9 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -9,9 +9,10 @@ from functools import partial from typing import Any, Dict, List, Optional, Set, Union -from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.agents import State +from haystack.core.component.component import component from haystack.core.component.sockets import Sockets +from haystack.core.serialization import default_from_dict, default_to_dict, logging from haystack.dataclasses import ChatMessage, ToolCall from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback from haystack.tools import ( diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 1b702e4073..0d86f7183f 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -9,8 +9,6 @@ from typing import Any, Dict, Mapping, Optional, Set, Union, cast from haystack import logging, tracing - -# from haystack.components.agents import Agent from haystack.core.component import Component from haystack.core.errors import PipelineInvalidResumeStateError, PipelineRuntimeError from haystack.core.pipeline.base import ( @@ -85,195 +83,6 @@ def _run_component( return cast(Dict[Any, Any], component_output) - # ToDo: delete - def run_old( # noqa: PLR0915, PLR0912 - self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None - ) -> Dict[str, Any]: - """ - Runs the Pipeline with given input data. - - Usage: - ```python - from haystack import Pipeline, Document - from haystack.utils import Secret - from haystack.document_stores.in_memory import InMemoryDocumentStore - from haystack.components.retrievers.in_memory import InMemoryBM25Retriever - from haystack.components.generators import OpenAIGenerator - from haystack.components.builders.answer_builder import AnswerBuilder - from haystack.components.builders.prompt_builder import PromptBuilder - - # Write documents to InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents([ - Document(content="My name is Jean and I live in Paris."), - Document(content="My name is Mark and I live in Berlin."), - Document(content="My name is Giorgio and I live in Rome.") - ]) - - prompt_template = \"\"\" - Given these documents, answer the question. - Documents: - {% for doc in documents %} - {{ doc.content }} - {% endfor %} - Question: {{question}} - Answer: - \"\"\" - - retriever = InMemoryBM25Retriever(document_store=document_store) - prompt_builder = PromptBuilder(template=prompt_template) - llm = OpenAIGenerator(api_key=Secret.from_token(api_key)) - - rag_pipeline = Pipeline() - rag_pipeline.add_component("retriever", retriever) - rag_pipeline.add_component("prompt_builder", prompt_builder) - rag_pipeline.add_component("llm", llm) - rag_pipeline.connect("retriever", "prompt_builder.documents") - rag_pipeline.connect("prompt_builder", "llm") - - # Ask a question - question = "Who lives in Paris?" 
- results = rag_pipeline.run( - { - "retriever": {"query": question}, - "prompt_builder": {"question": question}, - } - ) - - print(results["llm"]["replies"]) - # Jean lives in Paris - ``` - - :param data: - A dictionary of inputs for the pipeline's components. Each key is a component name - and its value is a dictionary of that component's input parameters: - ``` - data = { - "comp1": {"input1": 1, "input2": 2}, - } - ``` - For convenience, this format is also supported when input names are unique: - ``` - data = { - "input1": 1, "input2": 2, - } - ``` - :param include_outputs_from: - Set of component names whose individual outputs are to be - included in the pipeline's output. For components that are - invoked multiple times (in a loop), only the last-produced - output is included. - :returns: - A dictionary where each entry corresponds to a component name - and its output. If `include_outputs_from` is `None`, this dictionary - will only contain the outputs of leaf components, i.e., components - without outgoing connections. - - :raises ValueError: - If invalid inputs are provided to the pipeline. - :raises PipelineRuntimeError: - If the Pipeline contains cycles with unsupported connections that would cause - it to get stuck and fail running. - Or if a Component fails or returns output in an unsupported type. - :raises PipelineMaxComponentRuns: - If a Component reaches the maximum number of times it can be run in this Pipeline. - """ - pipeline_running(self) - - # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not - # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() - self.warm_up() - - # normalize `data` - data = self._prepare_component_input_data(data) - - # Raise ValueError if input is malformed in some way - self.validate_input(data) - - if include_outputs_from is None: - include_outputs_from = set() - - # We create a list of components in the pipeline sorted by name, so that the algorithm runs deterministically - # and independent of insertion order into the pipeline. - ordered_component_names = sorted(self.graph.nodes.keys()) - - # We track component visits to decide if a component can run. - component_visits = dict.fromkeys(ordered_component_names, 0) - - # We need to access a component's receivers multiple times during a pipeline run. - # We store them here for easy access. 
- cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} - - cached_topological_sort = None - - pipeline_outputs: Dict[str, Any] = {} - with tracing.tracer.trace( - "haystack.pipeline.run", - tags={ - "haystack.pipeline.input_data": data, - "haystack.pipeline.output_data": pipeline_outputs, - "haystack.pipeline.metadata": self.metadata, - "haystack.pipeline.max_runs_per_component": self._max_runs_per_component, - }, - ) as span: - inputs = self._convert_to_internal_format(pipeline_inputs=data) - priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) - - # check if pipeline is blocked before execution - self.validate_pipeline(priority_queue) - - while True: - candidate = self._get_next_runnable_component(priority_queue, component_visits) - if candidate is None: - break - - priority, component_name, component = candidate - if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: - component_name, topological_sort = self._tiebreak_waiting_components( - component_name=component_name, - priority=priority, - priority_queue=priority_queue, - topological_sort=cached_topological_sort, - ) - cached_topological_sort = topological_sort - component = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) - - component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs - ) - # We need to add missing defaults using default values from input sockets because the run signature - # might not provide these defaults for components with inputs defined dynamically upon component - # initialization - component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) - - component_outputs = self._run_component( - component_name=component_name, - component=component, - inputs=component_inputs, - component_visits=component_visits, - parent_span=span, - ) - - # Updates global input state with component outputs and returns outputs that should go to - # pipeline outputs. - - component_pipeline_outputs = self._write_component_outputs( - component_name=component_name, - component_outputs=component_outputs, - inputs=inputs, - receivers=cached_receivers[component_name], - include_outputs_from=include_outputs_from, - ) - - if component_pipeline_outputs: - pipeline_outputs[component_name] = _deepcopy_with_exceptions(component_pipeline_outputs) - if self._is_queue_stale(priority_queue): - priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) - - return pipeline_outputs - def _handle_resume_state(self, resume_state: Dict[str, Any]) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: """ Handle resume state initialization. 
@@ -536,19 +345,20 @@ def run( # noqa: PLR0915, PLR0912 agent_breakpoint = False if isinstance(break_point, AgentBreakpoint): - # component_instance = component["instance"] - # if isinstance(component_instance, Agent): - component_inputs = handle_agent_break_point( - break_point, - component_name, - component_inputs, - inputs, - component_visits, - ordered_component_names, - data, - debug_path, - ) - agent_breakpoint = True + component_instance = component["instance"] + # Use type checking by class name to avoid circular import + if component_instance.__class__.__name__ == "Agent": + component_inputs = handle_agent_break_point( + break_point, + component_name, + component_inputs, + inputs, + component_visits, + ordered_component_names, + data, + debug_path, + ) + agent_breakpoint = True if not agent_breakpoint and isinstance(break_point, Breakpoint): breakpoint_triggered = check_regular_break_point(break_point, component_name, component_visits) From 6b2a1f14324828bcc4548b3e421fac5978240626 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Tue, 15 Jul 2025 11:37:48 +0200 Subject: [PATCH 06/21] decoupling resume and initial run() for agent --- haystack/components/agents/agent.py | 49 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 713ab0ecbc..1d7bc3ea5b 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -423,23 +423,20 @@ def run( # noqa: PLR0915 state.set("messages", messages) else: - # initialize new state if not resuming + if self.system_prompt is not None: + messages = [ChatMessage.from_system(self.system_prompt)] + messages + + if all(m.is_from(ChatRole.SYSTEM) for m in messages): + logger.warning( + "All messages provided to the Agent component are system messages. This is not recommended as the " + "Agent will not perform any actions specific to user input. Consider adding user messages to the " + "input." + ) + state = State(schema=self.state_schema, data=kwargs) state.set("messages", messages) component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) - if self.system_prompt is not None: - messages = [ChatMessage.from_system(self.system_prompt)] + messages - - if all(m.is_from(ChatRole.SYSTEM) for m in messages): - logger.warning( - "All messages provided to the Agent component are system messages. This is not recommended as the " - "Agent will not perform any actions specific to user input. Consider adding user messages to the input." 
-            )
-
-        # state = State(schema=self.state_schema, data=kwargs)
-        # state.set("messages", messages)
-
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
         )
@@ -571,6 +568,7 @@ async def run_async(  # noqa: PLR0915
         if resume_state:
             # Extract component visits from pipeline state
             component_visits = resume_state.get("pipeline_state", {}).get("component_visits", {})
+
             # Initialize with default values if not present in resume state
             component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) | component_visits
@@ -580,27 +578,26 @@ async def run_async(  # noqa: PLR0915
         # Extract and deserialize messages from pipeline state
         raw_messages = resume_state.get("pipeline_state", {}).get("inputs", {}).get("messages", messages)
+
         # Convert raw message dictionaries to ChatMessage objects
         messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages]
         state.set("messages", messages)
+
         else:
-            # Initialize new state if not resuming
+            if self.system_prompt is not None:
+                messages = [ChatMessage.from_system(self.system_prompt)] + messages
+
+            if all(m.is_from(ChatRole.SYSTEM) for m in messages):
+                logger.warning(
+                    "All messages provided to the Agent component are system messages. This is not recommended as the "
+                    "Agent will not perform any actions specific to user input. Consider adding user messages to the "
+                    "input."
+                )
+
             state = State(schema=self.state_schema, data=kwargs)
             state.set("messages", messages)
             component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)

-            if self.system_prompt is not None:
-                messages = [ChatMessage.from_system(self.system_prompt)] + messages
-
-            if all(m.is_from(ChatRole.SYSTEM) for m in messages):
-                logger.warning(
-                    "All messages provided to the Agent component are system messages. This is not recommended as the "
-                    "Agent will not perform any actions specific to user input. Consider adding user messages to the input."
-                )
-
-            # state = State(schema=self.state_schema, data=kwargs)
-            # state.set("messages", messages)
-
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
         )

From 9030636bc40ffef96906e5c769658bf6f782b886 Mon Sep 17 00:00:00 2001
From: "David S. Batista"
Date: Tue, 15 Jul 2025 11:42:11 +0200
Subject: [PATCH 07/21] adding release notes

---
 ...eat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml

diff --git a/releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml b/releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml
new file mode 100644
index 0000000000..74c10fd100
--- /dev/null
+++ b/releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    The Pipeline and the Agent now support breakpoints, a feature useful for debugging. A breakpoint is associated with a component; when triggered, it stops the execution of the Pipeline/Agent and generates a JSON file with the execution status, which can be inspected,
+    edited, and later used to resume the execution.
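For illustration, here is a minimal sketch of the break-and-resume workflow this release
note describes, written against the post-refactor names introduced in PATCH 10/21 below
(`Breakpoint`, `BreakpointException`, `load_pipeline_snapshot`, and `debug_path` moving
onto the `Breakpoint` dataclass). The `pipeline` object, the "retriever" component, and
the snapshot filename are placeholder assumptions, not code from these patches:

```python
from haystack.core.errors import BreakpointException
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses.breakpoints import Breakpoint

# Stop before the first visit of a component named "retriever"; after PATCH 10,
# the directory for the JSON snapshot is configured on the Breakpoint itself.
break_point = Breakpoint(component_name="retriever", visit_count=0, debug_path="snapshots")

try:
    # `pipeline` is assumed to be an already-built Pipeline containing a "retriever" component
    pipeline.run(data={"retriever": {"query": "Who lives in Paris?"}}, break_point=break_point)
except BreakpointException:
    pass  # execution stopped at "retriever" and a JSON snapshot was written to "snapshots/"

# Inspect (and optionally edit) the JSON file, then resume from it;
# the filename below is illustrative.
snapshot = load_pipeline_snapshot("snapshots/retriever_2025_07_16_15_16_47.json")
results = pipeline.run(data={}, pipeline_snapshot=snapshot)
```

From 98dee135d4b7dfaf1af427820a6c29565d620a30 Mon Sep 17 00:00:00 2001
From: "David S.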
Batista" Date: Tue, 15 Jul 2025 16:12:34 +0200 Subject: [PATCH 08/21] re-raising BreakPointException from pipeline.run() --- haystack/core/pipeline/pipeline.py | 8 +++- .../test_agent_breakpoints_inside_pipeline.py | 42 ------------------- 2 files changed, 7 insertions(+), 43 deletions(-) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 0d86f7183f..30f149eab0 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -10,7 +10,7 @@ from haystack import logging, tracing from haystack.core.component import Component -from haystack.core.errors import PipelineInvalidResumeStateError, PipelineRuntimeError +from haystack.core.errors import BreakpointException, PipelineInvalidResumeStateError, PipelineRuntimeError from haystack.core.pipeline.base import ( _COMPONENT_INPUT, _COMPONENT_OUTPUT, @@ -71,8 +71,14 @@ def _run_component( logger.info("Running component {component_name}", component_name=component_name) try: component_output = instance.run(**inputs) + except BreakpointException as error: + # Re-raise BreakpointException to preserve the original exception context + # This is important when Agent components internally use Pipeline._run_component + # and trigger breakpoints that need to bubble up to the main pipeline + raise error except Exception as error: raise PipelineRuntimeError.from_exception(component_name, instance.__class__, error) from error + component_visits[component_name] += 1 if not isinstance(component_output, Mapping): diff --git a/test/components/agents/test_agent_breakpoints_inside_pipeline.py b/test/components/agents/test_agent_breakpoints_inside_pipeline.py index 6a56e6ec17..d5a019b538 100644 --- a/test/components/agents/test_agent_breakpoints_inside_pipeline.py +++ b/test/components/agents/test_agent_breakpoints_inside_pipeline.py @@ -234,17 +234,6 @@ def test_chat_generator_breakpoint_in_pipeline_agent(): assert e.state is not None assert "messages" in e.state assert e.results is not None - except PipelineRuntimeError as e: - # propagated exception to core Pipeline - assure that the cause is a PipelineBreakpointException - if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException): - original_exception = e.__cause__ - assert original_exception.component == "chat_generator" - assert original_exception.state is not None - assert "messages" in original_exception.state - assert original_exception.results is not None - else: - # re-raise if it's a different PipelineRuntimeError - test failed - raise # verify that debug/state file was created chat_generator_state_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) @@ -269,17 +258,6 @@ def test_tool_breakpoint_in_pipeline_agent(): assert e.state is not None assert "messages" in e.state assert e.results is not None - except PipelineRuntimeError as e: - # propagated exception to core Pipeline - assure that the cause is a PipelineBreakpointException - if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException): - original_exception = e.__cause__ - assert original_exception.component == "tool_invoker" - assert original_exception.state is not None - assert "messages" in original_exception.state - assert original_exception.results is not None - else: - # re-raise if it's a different PipelineRuntimeError - test failed - raise # verify that debug/state file was created tool_invoker_state_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) @@ -306,16 +284,6 @@ def 
test_agent_breakpoint_chat_generator_and_resume_pipeline():
         assert "messages" in e.state
         assert e.results is not None
-
-    except PipelineRuntimeError as e:
-        if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException):
-            original_exception = e.__cause__
-            assert original_exception.component == "chat_generator"
-            assert original_exception.state is not None
-            assert "messages" in original_exception.state
-            assert original_exception.results is not None
-        else:
-            raise
-
     # verify that the state file was created
     chat_generator_state_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
     assert len(chat_generator_state_files) > 0, f"No chat_generator state files found in {debug_path}"
@@ -368,16 +336,6 @@ def test_agent_breakpoint_tool_and_resume_pipeline():
         assert "messages" in e.state
         assert e.results is not None
-
-    except PipelineRuntimeError as e:
-        if hasattr(e, "__cause__") and isinstance(e.__cause__, BreakpointException):
-            original_exception = e.__cause__
-            assert original_exception.component == "tool_invoker"
-            assert original_exception.state is not None
-            assert "messages" in original_exception.state
-            assert original_exception.results is not None
-        else:
-            raise
-
     # verify that the state file was created
     tool_invoker_state_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
     assert len(tool_invoker_state_files) > 0, f"No tool_invoker state files found in {debug_path}"

From f8bfe4c162b5cc6e78feffcbea06260360e5dd43 Mon Sep 17 00:00:00 2001
From: "David S. Batista"
Date: Tue, 15 Jul 2025 16:39:01 +0200
Subject: [PATCH 09/21] fixing imports

---
 haystack/core/pipeline/pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py
index 955aa02aef..3727d64c0a 100644
--- a/haystack/core/pipeline/pipeline.py
+++ b/haystack/core/pipeline/pipeline.py
@@ -6,7 +6,7 @@

 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, Mapping, Optional, Set, Union, cast
+from typing import Any, Dict, Mapping, Optional, Set, Union

 from haystack import logging, tracing
 from haystack.core.component import Component

From 7551fce8acbee96ba73e49d013ef7a39f3ec54d0 Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
Date: Wed, 16 Jul 2025 15:16:47 +0200
Subject: [PATCH 10/21] refactor: Refactor suggestions for Pipeline breakpoints
 (#9614)

* Refactoring
* Start adding debug_path into Breakpoint class
* Fully move debug_path into Breakpoint dataclass
* Simplifications in pipeline run logic
* More simplification
* lint
* More simplification
* Updates
* Rename resume_state to pipeline_snapshot
* PR comments
* Missed renaming of state in a few more places
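---

As a usage sketch of the renames listed above (`resume_state` becomes `pipeline_snapshot`;
`debug_path` moves onto the `Breakpoint` dataclass), the following shows an agent-level
breakpoint with the new keywords. The `agent` object, the "weather_tool" tool, the exact
constructor fields, and the snapshot filename are assumptions for illustration, not
literal code from this patch:

```python
from haystack.core.errors import BreakpointException
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint

# Break on the first visit of the agent's tool_invoker, but only when the LLM
# calls the (hypothetical) "weather_tool"; tool_name=None would break on any tool.
tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool")
agent_bp = AgentBreakpoint(break_point=tool_bp)

messages = [ChatMessage.from_user("What is the weather in Berlin?")]
try:
    # `agent` is assumed to be a warmed-up Agent whose tools include "weather_tool"
    agent.run(messages=messages, break_point=agent_bp)
except BreakpointException:
    pass  # a JSON snapshot of the agent's execution state was written to disk

# Resume with the renamed keyword argument (previously `resume_state`);
# the snapshot filename is illustrative.
snapshot = load_pipeline_snapshot("snapshots/agent_tool_invoker_2025_07_16_15_16_47.json")
agent.run(messages=[], pipeline_snapshot=snapshot)
```

 haystack/components/agents/agent.py           | 242 +++++-----------
 haystack/core/errors.py                       |   8 +-
 haystack/core/pipeline/breakpoint.py          | 262 +++++++++++++-----
 haystack/core/pipeline/pipeline.py            | 177 ++++++------
 haystack/dataclasses/breakpoints.py           |  94 +++----
 .../test_agent_breakpoints_inside_pipeline.py | 120 ++++----
 .../test_agent_breakpoints_isolation_async.py |  74 +++--
 .../test_agent_breakpoints_isolation_sync.py  |  65 ++---
 .../agents/test_agent_breakpoints_utils.py    |  12 +-
 test/conftest.py                              |  27 +-
 test/core/pipeline/test_breakpoint.py         |  55 ++--
 ...test_pipeline_breakpoints_answer_joiner.py |  27 +-
 ...test_pipeline_breakpoints_branch_joiner.py |  26 +-
 .../test_pipeline_breakpoints_list_joiner.py  |  27 +-
 .../test_pipeline_breakpoints_loops.py        |  16 +-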
.../test_pipeline_breakpoints_rag_hybrid.py | 37 ++- ...test_pipeline_breakpoints_string_joiner.py | 24 +- 17 files changed, 619 insertions(+), 674 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 5708b2f82a..7fc18a54c6 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -3,22 +3,19 @@ # SPDX-License-Identifier: Apache-2.0 import inspect -from copy import deepcopy -from pathlib import Path from typing import Any, Dict, List, Optional, Union from haystack import logging, tracing from haystack.components.generators.chat.types import ChatGenerator from haystack.components.tools import ToolInvoker from haystack.core.component.component import component -from haystack.core.errors import BreakpointException from haystack.core.pipeline.async_pipeline import AsyncPipeline -from haystack.core.pipeline.breakpoint import _save_state +from haystack.core.pipeline.breakpoint import _check_chat_generator_breakpoint, _check_tool_invoker_breakpoint from haystack.core.pipeline.pipeline import Pipeline from haystack.core.pipeline.utils import _deepcopy_with_exceptions from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole -from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint +from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset from haystack.utils.callable_serialization import deserialize_callable, serialize_callable @@ -250,126 +247,13 @@ def _validate_tool_breakpoint_is_valid(self, agent_breakpoint: AgentBreakpoint) if tool_breakpoint.tool_name is not None and tool_breakpoint.tool_name not in available_tool_names: # type: ignore # was checked outside function raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools") # type: ignore # was checked outside function - def _check_chat_generator_breakpoint( # pylint: disable=too-many-positional-arguments - self, - agent_breakpoint: Optional[AgentBreakpoint], - component_visits: Dict[str, int], - messages: List[ChatMessage], - generator_inputs: Dict[str, Any], - debug_path: Optional[Union[str, Path]], - kwargs: Dict[str, Any], - state: State, - ) -> None: - """ - Check for breakpoint before calling the ChatGenerator. 
- - :param agent_breakpoint: AgentBreakpoint object containing breakpoints - :param component_visits: Dictionary tracking component visit counts - :param messages: Current messages to process - :param generator_inputs: Inputs for the chat generator - :param debug_path: Path for saving debug state - :param kwargs: Additional keyword arguments - :param state: Current agent state - :raises AgentBreakpointException: If a breakpoint is triggered - """ - - if agent_breakpoint and isinstance(agent_breakpoint.break_point, Breakpoint): - break_point = agent_breakpoint.break_point - if component_visits[break_point.component_name] == break_point.visit_count: - state_inputs = deepcopy({"messages": messages, **generator_inputs}) - _save_state( - inputs=state_inputs, - component_name=break_point.component_name, - component_visits=component_visits, # these are the component visits of the agent components - debug_path=debug_path, - original_input_data={"messages": messages, **kwargs}, - ordered_component_names=["chat_generator", "tool_invoker"], - agent_name=self._agent_name, - main_pipeline_state=state.data.get("main_pipeline_state", {}), - ) - msg = ( - f"Breaking at {break_point.component_name} visit count " - f"{component_visits[break_point.component_name]}" - ) - logger.info(msg) - raise BreakpointException( - message=msg, component=break_point.component_name, state=state_inputs, results=state.data - ) - - def _check_tool_invoker_breakpoint( # pylint: disable=too-many-positional-arguments - self, - agent_breakpoint: Optional[AgentBreakpoint], - component_visits: Dict[str, int], - llm_messages: List[ChatMessage], - streaming_callback: Optional[StreamingCallbackT], - debug_path: Optional[Union[str, Path]], - messages: List[ChatMessage], - kwargs: Dict[str, Any], - state: State, - ) -> None: - """ - Check for breakpoint before calling the ToolInvoker. 
- - :param agent_breakpoint: AgentBreakpoint object containing breakpoints - :param component_visits: Dictionary tracking component visit counts - :param llm_messages: Messages from the LLM - :param state: Current agent state - :param streaming_callback: Streaming callback function - :param debug_path: Path for saving debug state - :param messages: Original messages - :param kwargs: Additional keyword arguments - :raises AgentBreakpointException: If a breakpoint is triggered - """ - - if agent_breakpoint and isinstance(agent_breakpoint.break_point, ToolBreakpoint): - tool_breakpoint = agent_breakpoint.break_point - # Check if the visit count matches - if component_visits[tool_breakpoint.component_name] == tool_breakpoint.visit_count: - # Check if we should break for this specific tool or all tools - should_break = False - if tool_breakpoint.tool_name is None: - # Break for any tool call - should_break = any(msg.tool_call for msg in llm_messages) - else: - # Break only for the specific tool - should_break = any( - msg.tool_call and msg.tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages - ) - - if should_break: - state_inputs = deepcopy( - {"messages": llm_messages, "state": state, "streaming_callback": streaming_callback} - ) - _save_state( - inputs=state_inputs, - component_name=tool_breakpoint.component_name, - component_visits=component_visits, - debug_path=debug_path, - original_input_data={"messages": messages, **kwargs}, - ordered_component_names=["chat_generator", "tool_invoker"], - agent_name=self._agent_name, - main_pipeline_state=state.data.get("main_pipeline_state", {}), - ) - msg = ( - f"Breaking at {tool_breakpoint.component_name} visit count " - f"{component_visits[tool_breakpoint.component_name]}" - ) - if tool_breakpoint.tool_name: - msg += f" for tool {tool_breakpoint.tool_name}" - logger.info(msg) - - raise BreakpointException( - message=msg, component=tool_breakpoint.component_name, state=state_inputs, results=state.data - ) - - def run( # noqa: PLR0915 + def run( self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, break_point: Optional[AgentBreakpoint] = None, - resume_state: Optional[Dict[str, Any]] = None, - debug_path: Optional[Union[str, Path]] = None, + pipeline_snapshot: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Dict[str, Any]: """ @@ -381,8 +265,7 @@ def run( # noqa: PLR0915 The same callback can be configured to emit tool results when a tool is called. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". - :param resume_state: A dictionary containing the state of a previously saved agent execution. - :param debug_path: Path to the directory where the agent state should be saved. + :param pipeline_snapshot: A dictionary containing the state of a previously saved agent execution. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -397,31 +280,25 @@ def run( # noqa: PLR0915 if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.") - if break_point and resume_state: - msg = ( - "agent_breakpoint and resume_state cannot be provided at the same time. The agent run will be aborted." 
+ if break_point and pipeline_snapshot: + raise ValueError( + "agent_breakpoint and pipeline_snapshot cannot be provided at the same time. " + "The agent run will be aborted." ) - raise ValueError(msg) - - self._agent_name = self.__component_name__ if hasattr(self, "__component_name__") else "isolated_agent" # validate breakpoints if break_point and isinstance(break_point.break_point, ToolBreakpoint): self._validate_tool_breakpoint_is_valid(break_point) - # resume state if provided - if resume_state: - component_visits = resume_state.get("pipeline_state", {}).get("component_visits", {}) - state_data = resume_state.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {}) - state = State(schema=self.state_schema, data=state_data) + # Handle pipeline snapshot if provided + if pipeline_snapshot: + component_visits = pipeline_snapshot.get("pipeline_state", {}).get("component_visits", {}) + state_data = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {}) # deserialize messages from pipeline state - raw_messages = resume_state.get("pipeline_state", {}).get("inputs", {}).get("messages", messages) - + raw_messages = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("messages", messages) # convert raw message dictionaries to ChatMessage objects and populate the state messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages] - state.set("messages", messages) - else: if self.system_prompt is not None: messages = [ChatMessage.from_system(self.system_prompt)] + messages @@ -432,10 +309,11 @@ def run( # noqa: PLR0915 "Agent will not perform any actions specific to user input. Consider adding user messages to the " "input." ) - - state = State(schema=self.state_schema, data=kwargs) - state.set("messages", messages) component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) + state_data = kwargs + + state = State(schema=self.state_schema, data=state_data) + state.set("messages", messages) streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False @@ -448,13 +326,15 @@ def run( # noqa: PLR0915 ) counter = 0 - if break_point and self._agent_name is None: - raise ValueError("When using breakpoints, the agent_name must be provided to save the state correctly.") - while counter < self.max_agent_steps: # check for breakpoint before ChatGenerator - self._check_chat_generator_breakpoint( - break_point, component_visits, messages, generator_inputs, debug_path, kwargs, state + _check_chat_generator_breakpoint( + agent_breakpoint=break_point, + component_visits=component_visits, + messages=messages, + generator_inputs=generator_inputs, + kwargs=kwargs, + state=state, ) # 1. Call the ChatGenerator @@ -474,8 +354,14 @@ def run( # noqa: PLR0915 break # check for breakpoint before ToolInvoker - self._check_tool_invoker_breakpoint( - break_point, component_visits, llm_messages, streaming_callback, debug_path, messages, kwargs, state + _check_tool_invoker_breakpoint( + agent_breakpoint=break_point, + component_visits=component_visits, + llm_messages=llm_messages, + streaming_callback=streaming_callback, + messages=messages, + kwargs=kwargs, + state=state, ) # 3. 
Call the ToolInvoker @@ -520,8 +406,7 @@ async def run_async( # noqa: PLR0915 streaming_callback: Optional[StreamingCallbackT] = None, *, break_point: Optional[AgentBreakpoint] = None, - resume_state: Optional[Dict[str, Any]] = None, - debug_path: Optional[Union[str, Path]] = None, + pipeline_snapshot: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Dict[str, Any]: """ @@ -536,8 +421,7 @@ async def run_async( # noqa: PLR0915 is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". - :param resume_state: A dictionary containing the state of a previously saved agent execution. - :param debug_path: Path to the directory where the agent state should be saved. + :param pipeline_snapshot: A dictionary containing the state of a previously saved agent execution. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -552,37 +436,26 @@ async def run_async( # noqa: PLR0915 if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.") - if break_point and resume_state: + if break_point and pipeline_snapshot: msg = ( - "agent_breakpoint and resume_state cannot be provided at the same time. The agent run will be aborted." + "agent_breakpoint and pipeline_snapshot cannot be provided at the same time. " + "The agent run will be aborted." ) raise ValueError(msg) - self._agent_name = self.__component_name__ if hasattr(self, "__component_name__") else "isolated_agent" - # validate breakpoints if break_point and isinstance(break_point.break_point, ToolBreakpoint): self._validate_tool_breakpoint_is_valid(break_point) - # Handle resume state if provided - if resume_state: - # Extract component visits from pipeline state - component_visits = resume_state.get("pipeline_state", {}).get("component_visits", {}) - - # Initialize with default values if not present in resume state - component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) | component_visits - - # Extract state data from pipeline state - state_data = resume_state.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {}) - state = State(schema=self.state_schema, data=state_data) + # Handle pipeline snapshot if provided + if pipeline_snapshot: + component_visits = pipeline_snapshot.get("pipeline_state", {}).get("component_visits", {}) + state_data = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {}) # Extract and deserialize messages from pipeline state - raw_messages = resume_state.get("pipeline_state", {}).get("inputs", {}).get("messages", messages) - + raw_messages = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("messages", messages) # Convert raw message dictionaries to ChatMessage objects messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages] - state.set("messages", messages) - else: if self.system_prompt is not None: messages = [ChatMessage.from_system(self.system_prompt)] + messages @@ -593,10 +466,11 @@ async def run_async( # noqa: PLR0915 "Agent will not perform any actions specific to user input. Consider adding user messages to the " "input." 
) - - state = State(schema=self.state_schema, data=kwargs) - state.set("messages", messages) component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) + state_data = kwargs + + state = State(schema=self.state_schema, data=state_data) + state.set("messages", messages) streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True @@ -609,13 +483,15 @@ async def run_async( # noqa: PLR0915 ) counter = 0 - if break_point and self._agent_name is None: - raise ValueError("When using breakpoints, the agent_name must be provided to save the state correctly.") - while counter < self.max_agent_steps: # Check for breakpoint before ChatGenerator - self._check_chat_generator_breakpoint( - break_point, component_visits, messages, generator_inputs, debug_path, kwargs, state + _check_chat_generator_breakpoint( + agent_breakpoint=break_point, + component_visits=component_visits, + messages=messages, + generator_inputs=generator_inputs, + kwargs=kwargs, + state=state, ) # 1. Call the ChatGenerator @@ -635,8 +511,14 @@ async def run_async( # noqa: PLR0915 break # Check for breakpoint before ToolInvoker - self._check_tool_invoker_breakpoint( - break_point, component_visits, llm_messages, streaming_callback, debug_path, messages, kwargs, state + _check_tool_invoker_breakpoint( + agent_breakpoint=break_point, + component_visits=component_visits, + llm_messages=llm_messages, + streaming_callback=streaming_callback, + messages=messages, + kwargs=kwargs, + state=state, ) # 3. Call the ToolInvoker diff --git a/haystack/core/errors.py b/haystack/core/errors.py index 9de137da1f..a05e1ea705 100644 --- a/haystack/core/errors.py +++ b/haystack/core/errors.py @@ -100,18 +100,18 @@ def __init__( self, message: str, component: Optional[str] = None, - state: Optional[Dict[str, Any]] = None, + pipeline_snapshot: Optional[Dict[str, Any]] = None, results: Optional[Dict[str, Any]] = None, ): super().__init__(message) self.component = component - self.state = state + self.pipeline_snapshot = pipeline_snapshot self.results = results -class PipelineInvalidResumeStateError(Exception): +class PipelineInvalidPipelineSnapshotError(Exception): """ - Exception raised when a pipeline is resumed from an invalid state. + Exception raised when a pipeline is resumed from an invalid snapshot. 
""" def __init__(self, message: str): diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py index bc1a4adf5b..41c03d1551 100644 --- a/haystack/core/pipeline/breakpoint.py +++ b/haystack/core/pipeline/breakpoint.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=too-many-return-statements, too-many-positional-arguments +# pylint: disable=too-many-return-statements import json from copy import deepcopy @@ -13,7 +13,9 @@ from networkx import MultiDiGraph from haystack import logging -from haystack.core.errors import BreakpointException, PipelineInvalidResumeStateError +from haystack.components.agents.state import State +from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError +from haystack.dataclasses import ChatMessage, StreamingCallbackT from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint from haystack.utils.base_serialization import _serialize_value_with_schema @@ -49,54 +51,56 @@ def _validate_break_point(break_point: Union[Breakpoint, AgentBreakpoint], graph ) -def _validate_components_against_pipeline(resume_state: Dict[str, Any], graph: MultiDiGraph) -> None: +def _validate_components_against_pipeline(pipeline_snapshot: Dict[str, Any], graph: MultiDiGraph) -> None: """ - Validates that the resume_state contains valid configuration for the current pipeline. + Validates that the pipeline_snapshot contains valid configuration for the current pipeline. - Raises a PipelineInvalidResumeStateError if any component in resume_state is not part of the target pipeline. + Raises a PipelineInvalidPipelineSnapshotError if any component in pipeline_snapshot is not part of the + target pipeline. - :param resume_state: The saved state to validate. + :param pipeline_snapshot: The saved state to validate. """ - pipeline_state = resume_state["pipeline_state"] + pipeline_state = pipeline_snapshot["pipeline_state"] valid_components = set(graph.nodes.keys()) # Check if the ordered_component_names are valid components in the pipeline invalid_ordered_components = set(pipeline_state["ordered_component_names"]) - valid_components if invalid_ordered_components: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {invalid_ordered_components} in 'ordered_component_names' " + raise PipelineInvalidPipelineSnapshotError( + f"Invalid pipeline snapshot: components {invalid_ordered_components} in 'ordered_component_names' " f"are not part of the current pipeline." ) # Check if the input_data is valid components in the pipeline - serialized_input_data = resume_state["input_data"]["serialized_data"] + serialized_input_data = pipeline_snapshot["input_data"]["serialized_data"] invalid_input_data = set(serialized_input_data.keys()) - valid_components if invalid_input_data: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {invalid_input_data} in 'input_data' " + raise PipelineInvalidPipelineSnapshotError( + f"Invalid pipeline snapshot: components {invalid_input_data} in 'input_data' " f"are not part of the current pipeline." 
) # Validate 'component_visits' invalid_component_visits = set(pipeline_state["component_visits"].keys()) - valid_components if invalid_component_visits: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {invalid_component_visits} in 'component_visits' " + raise PipelineInvalidPipelineSnapshotError( + f"Invalid pipeline snapshot: components {invalid_component_visits} in 'component_visits' " f"are not part of the current pipeline." ) logger.info( - f"Resuming pipeline from component: {resume_state['pipeline_breakpoint']['component']} " - f"(visit {resume_state['pipeline_breakpoint']['visits']})" + f"Resuming pipeline from component: {pipeline_snapshot['pipeline_breakpoint']['component']} " + f"(visit {pipeline_snapshot['pipeline_breakpoint']['visits']})" ) -def _validate_resume_state(resume_state: Dict[str, Any]) -> None: +def _validate_pipeline_snapshot(pipeline_snapshot: Dict[str, Any]) -> None: """ - Validates the loaded pipeline resume_state. + Validates the loaded pipeline snapshot. - Ensures that the resume_state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". + Ensures that the pipeline_snapshot contains required keys: "input_data", "pipeline_breakpoint", + and "pipeline_state". Raises: ValueError: If required keys are missing or the component sets are inconsistent. @@ -104,12 +108,12 @@ def _validate_resume_state(resume_state: Dict[str, Any]) -> None: # top-level state has all required keys required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} - missing_top = required_top_keys - resume_state.keys() + missing_top = required_top_keys - pipeline_snapshot.keys() if missing_top: - raise ValueError(f"Invalid state file: missing required keys {missing_top}") + raise ValueError(f"Invalid pipeline_snapshot: missing required keys {missing_top}") # pipeline_state has the necessary keys - pipeline_state = resume_state["pipeline_state"] + pipeline_state = pipeline_snapshot["pipeline_state"] required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} missing_pipeline = required_pipeline_keys - pipeline_state.keys() @@ -126,23 +130,23 @@ def _validate_resume_state(resume_state: Dict[str, Any]) -> None: f"do not match components in ordered_component_names {components_in_order}" ) - logger.info("Passed resume state validated successfully.") + logger.info("Pipeline snapshot validated successfully.") -def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: +def load_pipeline_snapshot(file_path: Union[str, Path]) -> Dict[str, Any]: """ - Load a saved pipeline state. + Load a saved pipeline snapshot. - :param file_path: Path to the resume_state file. + :param file_path: Path to the pipeline_snapshot file. :returns: - Dict containing the loaded resume_state. + Dict containing the loaded pipeline_snapshot. 
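+
+    Example (illustrative; the file name follows the pattern produced when a breakpoint with a debug_path is hit):
+
+        snapshot = load_pipeline_snapshot("debug/my_component_2025_01_01_12_00_00.json")
+        pipeline.run(data={}, pipeline_snapshot=snapshot)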
""" file_path = Path(file_path) try: with open(file_path, "r", encoding="utf-8") as f: - state = json.load(f) + pipeline_snapshot = json.load(f) except FileNotFoundError: raise FileNotFoundError(f"File not found: {file_path}") except json.JSONDecodeError as e: @@ -151,12 +155,12 @@ def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: raise IOError(f"Error reading {file_path}: {str(e)}") try: - _validate_resume_state(resume_state=state) + _validate_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot) except ValueError as e: - raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") + raise ValueError(f"Invalid pipeline snapshot from {file_path}: {str(e)}") - logger.info(f"Successfully loaded pipeline state from: {file_path}") - return state + logger.info(f"Successfully loaded the pipeline snapshot from: {file_path}") + return pipeline_snapshot def _process_main_pipeline_state(main_pipeline_state: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: @@ -184,26 +188,19 @@ def _process_main_pipeline_state(main_pipeline_state: Optional[Dict[str, Any]]) } -def _save_state_to_file( - state: Dict[str, Any], - debug_path: Union[str, Path], - dt: datetime, - is_agent: bool, - agent_name: Optional[str], - component_name: str, +def _save_pipeline_snapshot_to_file( + *, pipeline_snapshot: Dict[str, Any], debug_path: Union[str, Path], dt: datetime, component_name: str ) -> None: """ - Save state dictionary to a JSON file. + Save the pipeline snapshot dictionary to a JSON file. - :param state: The state dictionary to save. + :param pipeline_snapshot: The pipeline snapshot to save. :param debug_path: The path where to save the file. :param dt: The datetime object for timestamping. - :param is_agent: Whether this is an agent pipeline. - :param agent_name: Name of the agent (if applicable). :param component_name: Name of the component that triggered the breakpoint. :raises: ValueError: If the debug_path is not a string or a Path object. - Exception: If saving the JSON state fails. + Exception: If saving the JSON snapshot fails. """ debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path if not isinstance(debug_path, Path): @@ -212,21 +209,23 @@ def _save_state_to_file( debug_path.mkdir(exist_ok=True) # Generate filename - if is_agent: - file_name = f"{agent_name}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + # We check if the agent_name is provided to differentiate between agent and non-agent breakpoints + if pipeline_snapshot["agent_name"] is not None: + file_name = f"{pipeline_snapshot['agent_name']}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" else: file_name = f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" try: with open(debug_path / file_name, "w") as f_out: - json.dump(state, f_out, indent=2) - logger.info(f"Pipeline state saved at: {file_name}") + json.dump(pipeline_snapshot, f_out, indent=2) + logger.info(f"Pipeline snapshot saved at: {file_name}") except Exception as e: - logger.error(f"Failed to save pipeline state: {str(e)}") + logger.error(f"Failed to save pipeline snapshot: {str(e)}") raise -def _save_state( +def _save_snapshot( + *, inputs: Dict[str, Any], component_name: str, component_visits: Dict[str, int], @@ -237,19 +236,19 @@ def _save_state( main_pipeline_state: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ - Save the pipeline state to a file. + Save the pipeline snapshot to a file. - :param inputs: The current pipeline state inputs. + :param inputs: The current pipeline snapshot inputs. 
:param component_name: The name of the component that triggered the breakpoint. :param component_visits: The visit count of the component that triggered the breakpoint. - :param debug_path: The path to save the state to. + :param debug_path: The path to save the snapshot to. :param original_input_data: The original input data. :param ordered_component_names: The ordered component names. :param main_pipeline_state: Dictionary containing main pipeline state with keys: "component_visits", "ordered_component_names", "original_input_data", and "inputs". :returns: - The dictionary containing the state of the pipeline containing the following keys: + The dictionary containing the snapshot of the pipeline containing the following keys: - input_data: The original input data passed to the pipeline. - timestamp: The timestamp of the breakpoint. - pipeline_breakpoint: The component name and visit count that triggered the breakpoint. @@ -267,7 +266,7 @@ def _save_state( transformed_original_input_data = _transform_json_structure(original_input_data) transformed_inputs = _transform_json_structure(inputs) - state = { + pipeline_snapshot = { # related to the main pipeline where the agent running as a breakpoint - only used with AgentBreakpoint "agent_name": agent_name if agent_name else None, "main_pipeline_state": _process_main_pipeline_state(main_pipeline_state) if agent_name else None, @@ -284,12 +283,13 @@ def _save_state( } if not debug_path: - return state + return pipeline_snapshot - is_agent = agent_name is not None - _save_state_to_file(state, debug_path, dt, is_agent, agent_name, component_name) + _save_pipeline_snapshot_to_file( + pipeline_snapshot=pipeline_snapshot, debug_path=debug_path, dt=dt, component_name=component_name + ) - return state + return pipeline_snapshot def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: @@ -322,7 +322,8 @@ def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> An return data -def handle_agent_break_point( +def _handle_agent_break_point( + *, break_point: AgentBreakpoint, component_name: str, component_inputs: Dict[str, Any], @@ -330,7 +331,6 @@ def handle_agent_break_point( component_visits: Dict[str, int], ordered_component_names: list, data: Dict[str, Any], - debug_path: Optional[Union[str, Path]], ) -> Dict[str, Any]: """ Handle agent-specific breakpoint logic. @@ -342,11 +342,9 @@ def handle_agent_break_point( :param component_visits: Component visit counts :param ordered_component_names: Ordered list of component names :param data: Original pipeline data - :param debug_path: Path for debug files :return: Updated component inputs """ component_inputs["break_point"] = break_point - component_inputs["debug_path"] = debug_path # Store pipeline state for agent resume state_inputs_serialised = deepcopy(inputs) @@ -361,7 +359,7 @@ def handle_agent_break_point( return component_inputs -def check_regular_break_point(break_point: Breakpoint, component_name: str, component_visits: Dict[str, int]) -> bool: +def _check_regular_break_point(break_point: Breakpoint, component_name: str, component_visits: Dict[str, int]) -> bool: """ Check if a regular breakpoint should be triggered. 
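+
+    For example (illustrative component name), a Breakpoint("retriever", visit_count=0) makes this check return
+    True on the first visit to a component named "retriever".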
@@ -373,7 +371,8 @@ def check_regular_break_point(break_point: Breakpoint, component_name: str, comp
     return break_point.component_name == component_name and break_point.visit_count == component_visits[component_name]
 
 
-def trigger_break_point(
+def _trigger_break_point(
+    *,
     component_name: str,
     component_inputs: Dict[str, Any],
     inputs: Dict[str, Any],
@@ -384,7 +383,7 @@ def trigger_break_point(
     pipeline_outputs: Dict[str, Any],
 ) -> None:
     """
-    Trigger a breakpoint by saving state and raising exception.
+    Trigger a breakpoint by saving a snapshot and raising an exception.
 
     :param component_name: Name of the component where breakpoint is triggered
     :param component_inputs: Inputs for the current component
@@ -396,10 +395,10 @@ def trigger_break_point(
     :param pipeline_outputs: Current pipeline outputs
     :raises PipelineBreakpointException: When breakpoint is triggered
     """
-    state_inputs_serialised = deepcopy(inputs)
-    state_inputs_serialised[component_name] = deepcopy(component_inputs)
-    _save_state(
-        inputs=state_inputs_serialised,
+    pipeline_snapshot_inputs_serialised = deepcopy(inputs)
+    pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
+    _save_snapshot(
+        inputs=pipeline_snapshot_inputs_serialised,
         component_name=str(component_name),
         component_visits=component_visits,
         debug_path=debug_path,
@@ -409,5 +408,126 @@ def trigger_break_point(
     msg = f"Breaking at component {component_name} at visit count {component_visits[component_name]}"
     raise BreakpointException(
-        message=msg, component=component_name, state=state_inputs_serialised, results=pipeline_outputs
+        message=msg,
+        component=component_name,
+        pipeline_snapshot=pipeline_snapshot_inputs_serialised,
+        results=pipeline_outputs,
     )
+
+
+def _check_chat_generator_breakpoint(
+    *,
+    agent_breakpoint: Optional[AgentBreakpoint],
+    component_visits: Dict[str, int],
+    messages: List[ChatMessage],
+    generator_inputs: Dict[str, Any],
+    kwargs: Dict[str, Any],
+    state: State,
+) -> None:
+    """
+    Check for breakpoint before calling the ChatGenerator.
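+
+    Example of a break point that triggers this check (illustrative agent name):
+
+        AgentBreakpoint(agent_name="my_agent", break_point=Breakpoint("chat_generator", visit_count=0))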
+
+    :param agent_breakpoint: AgentBreakpoint object containing breakpoints
+    :param component_visits: Dictionary tracking component visit counts
+    :param messages: Current messages to process
+    :param generator_inputs: Inputs for the chat generator
+    :param kwargs: Additional keyword arguments
+    :param state: The current State of the agent
+    :raises BreakpointException: If a breakpoint is triggered
+    """
+
+    # We also check component_name since ToolBreakpoint is a subclass of Breakpoint
+    if (
+        agent_breakpoint
+        and isinstance(agent_breakpoint.break_point, Breakpoint)
+        and agent_breakpoint.break_point.component_name == "chat_generator"
+    ):
+        break_point = agent_breakpoint.break_point
+        if component_visits[break_point.component_name] == break_point.visit_count:
+            chat_generator_inputs = deepcopy({"messages": messages, **generator_inputs})
+            _save_snapshot(
+                inputs=chat_generator_inputs,
+                component_name=break_point.component_name,
+                component_visits=component_visits,  # these are the component visits of the agent components
+                debug_path=break_point.debug_path,
+                original_input_data={"messages": messages, **kwargs},
+                ordered_component_names=["chat_generator", "tool_invoker"],
+                agent_name=agent_breakpoint.agent_name or "isolated_agent",
+                main_pipeline_state=state.data.get("main_pipeline_state", {}),
+            )
+            msg = f"Breaking at {break_point.component_name} visit count {component_visits[break_point.component_name]}"
+            logger.info(msg)
+            raise BreakpointException(
+                message=msg,
+                component=break_point.component_name,
+                pipeline_snapshot=chat_generator_inputs,
+                results=state.data,
+            )
+
+
+def _check_tool_invoker_breakpoint(
+    *,
+    agent_breakpoint: Optional[AgentBreakpoint],
+    component_visits: Dict[str, int],
+    llm_messages: List[ChatMessage],
+    streaming_callback: Optional[StreamingCallbackT],
+    messages: List[ChatMessage],
+    kwargs: Dict[str, Any],
+    state: State,
+) -> None:
+    """
+    Check for breakpoint before calling the ToolInvoker.
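+
+    Example of a break point that triggers this check (illustrative; "weather_tool" is a placeholder tool name):
+
+        tool_bp = ToolBreakpoint("tool_invoker", visit_count=0, tool_name="weather_tool")
+        agent_breakpoint = AgentBreakpoint(agent_name="my_agent", break_point=tool_bp)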
+
+    :param agent_breakpoint: AgentBreakpoint object containing breakpoints
+    :param component_visits: Dictionary tracking component visit counts
+    :param llm_messages: Messages from the LLM
+    :param streaming_callback: Streaming callback function
+    :param messages: Original messages
+    :param kwargs: Additional keyword arguments
+    :param state: Current agent state
+    :raises BreakpointException: If a breakpoint is triggered
+    """
+
+    if agent_breakpoint and isinstance(agent_breakpoint.break_point, ToolBreakpoint):
+        tool_breakpoint = agent_breakpoint.break_point
+        # Check if the visit count matches
+        if component_visits[tool_breakpoint.component_name] == tool_breakpoint.visit_count:
+            # Check if we should break for this specific tool or all tools
+            should_break = False
+            if tool_breakpoint.tool_name is None:
+                # Break for any tool call
+                should_break = any(msg.tool_call for msg in llm_messages)
+            else:
+                # Break only for the specific tool
+                should_break = any(
+                    msg.tool_call and msg.tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages
+                )
+
+            if should_break:
+                tool_invoker_inputs = deepcopy(
+                    {"messages": llm_messages, "state": state, "streaming_callback": streaming_callback}
+                )
+                _save_snapshot(
+                    inputs=tool_invoker_inputs,
+                    component_name=tool_breakpoint.component_name,
+                    component_visits=component_visits,
+                    debug_path=tool_breakpoint.debug_path,
+                    original_input_data={"messages": messages, **kwargs},
+                    ordered_component_names=["chat_generator", "tool_invoker"],
+                    agent_name=agent_breakpoint.agent_name or "isolated_agent",
+                    main_pipeline_state=state.data.get("main_pipeline_state", {}),
+                )
+                msg = (
+                    f"Breaking at {tool_breakpoint.component_name} visit count "
+                    f"{component_visits[tool_breakpoint.component_name]}"
+                )
+                if tool_breakpoint.tool_name:
+                    msg += f" for tool {tool_breakpoint.tool_name}"
+                logger.info(msg)
+
+                raise BreakpointException(
+                    message=msg,
+                    component=tool_breakpoint.component_name,
+                    pipeline_snapshot=tool_invoker_inputs,
+                    results=state.data,
+                )
diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py
index 3727d64c0a..7fb92fe35c 100644
--- a/haystack/core/pipeline/pipeline.py
+++ b/haystack/core/pipeline/pipeline.py
@@ -5,12 +5,11 @@
 # pylint: disable=too-many-positional-arguments
 
 from copy import deepcopy
-from pathlib import Path
 from typing import Any, Dict, Mapping, Optional, Set, Union
 
 from haystack import logging, tracing
 from haystack.core.component import Component
-from haystack.core.errors import BreakpointException, PipelineInvalidResumeStateError, PipelineRuntimeError
+from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError, PipelineRuntimeError
 from haystack.core.pipeline.base import (
     _COMPONENT_INPUT,
     _COMPONENT_OUTPUT,
@@ -19,11 +18,11 @@
     PipelineBase,
 )
 from haystack.core.pipeline.breakpoint import (
+    _check_regular_break_point,
+    _handle_agent_break_point,
+    _trigger_break_point,
     _validate_break_point,
     _validate_components_against_pipeline,
-    check_regular_break_point,
-    handle_agent_break_point,
-    trigger_break_point,
 )
 from haystack.core.pipeline.utils import _deepcopy_with_exceptions
 from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint
@@ -90,61 +89,63 @@ def _run_component(
 
         return component_output
 
-    def _handle_resume_state(self, resume_state: Dict[str, Any]) -> tuple[Dict[str, int], Dict[str, Any], bool, list]:
+    def _handle_resume_pipeline(
+        self, pipeline_snapshot: Dict[str, Any]
+    ) -> tuple[Dict[str, int], Dict[str,
Any], bool, list]: """ - Handle resume state initialization. + Handle resuming the pipeline from a pipeline snapshot. - :param resume_state: The resume state to handle + :param pipeline_snapshot: The snapshot of the pipeline to resume from. :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) """ - if resume_state.get("agent_name"): - return self._handle_agent_resume_state(resume_state) + if pipeline_snapshot.get("agent_name"): + return self._handle_resume_from_agent(pipeline_snapshot) else: - return self._handle_regular_resume_state(resume_state) + return self._handle_resume_from_pipeline_snapshot(pipeline_snapshot) - def _handle_agent_resume_state( - self, resume_state: Dict[str, Any] + def _handle_resume_from_agent( + self, pipeline_snapshot: Dict[str, Any] ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: """ - Handle agent-specific resume state. + Handle resuming the pipeline at a specific Agent component. - :param resume_state: The resume state to handle + :param pipeline_snapshot: The snapshot of the pipeline to resume from. :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) """ - agent_name = resume_state["agent_name"] + agent_name = pipeline_snapshot["agent_name"] for name, component in self.graph.nodes.items(): if component["instance"].__class__.__name__ == "Agent" and name == agent_name: - main_pipeline_state = resume_state.get("main_pipeline_state", {}) + main_pipeline_state = pipeline_snapshot.get("main_pipeline_state", {}) component_visits = main_pipeline_state.get("component_visits", {}) ordered_component_names = main_pipeline_state.get("ordered_component_names", []) data = _deserialize_value_with_schema(main_pipeline_state.get("inputs", {})) return component_visits, data, True, ordered_component_names # Fallback to regular resume if agent not found - return self._handle_regular_resume_state(resume_state) + return self._handle_resume_from_pipeline_snapshot(pipeline_snapshot) - def _handle_regular_resume_state( - self, resume_state: Dict[str, Any] + def _handle_resume_from_pipeline_snapshot( + self, pipeline_snapshot: Dict[str, Any] ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: """ - Handle regular component resume state. + Handle resuming the pipeline from a regular pipeline snapshot. - :param resume_state: The resume state to handle + :param pipeline_snapshot: The snapshot of the pipeline to resume from. :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) """ - component_visits, data, resume_state, ordered_component_names = self.inject_resume_state_into_graph( - resume_state=resume_state + component_visits, data, pipeline_snapshot, ordered_component_names = self._inject_pipeline_snapshot_into_graph( + pipeline_snapshot=pipeline_snapshot ) - data = _deserialize_value_with_schema(resume_state["pipeline_state"]["inputs"]) + data = _deserialize_value_with_schema(pipeline_snapshot["pipeline_state"]["inputs"]) return component_visits, data, False, ordered_component_names def run( # noqa: PLR0915, PLR0912, C901 self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None, + *, break_point: Optional[Union[Breakpoint, AgentBreakpoint]] = None, - resume_state: Optional[Dict[str, Any]] = None, - debug_path: Optional[Union[str, Path]] = None, + pipeline_snapshot: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Runs the Pipeline with given input data. 
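+
+        Example of breaking and resuming (illustrative component name and paths; load_pipeline_snapshot lives
+        in haystack.core.pipeline.breakpoint):
+
+            break_point = Breakpoint("retriever", visit_count=0, debug_path="debug/")
+            pipeline.run(data=data, break_point=break_point)  # raises BreakpointException, saves a snapshot
+            snapshot = load_pipeline_snapshot("debug/retriever_2025_01_01_12_00_00.json")
+            pipeline.run(data={}, pipeline_snapshot=snapshot)  # resumes from where the pipeline stopped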
@@ -224,12 +225,9 @@ def run( # noqa: PLR0915, PLR0912, C901 :param break_point: A set of breakpoints that can be used to debug the pipeline execution. - :param resume_state: + :param pipeline_snapshot: A dictionary containing the state of a previously saved pipeline execution. - :param debug_path: - Path to the directory where the pipeline state should be saved. - :returns: A dictionary where each entry corresponds to a component name and its output. If `include_outputs_from` is `None`, this dictionary @@ -249,12 +247,12 @@ def run( # noqa: PLR0915, PLR0912, C901 """ pipeline_running(self) - if break_point and resume_state: + if break_point and pipeline_snapshot: msg = ( - "pipeline_breakpoint and resume_state cannot be provided at the same time. " + "pipeline_breakpoint and pipeline_snapshot cannot be provided at the same time. " "The pipeline run will be aborted." ) - raise PipelineInvalidResumeStateError(message=msg) + raise PipelineInvalidPipelineSnapshotError(message=msg) # make sure all breakpoints are valid, i.e. reference components in the pipeline if break_point: @@ -267,7 +265,7 @@ def run( # noqa: PLR0915, PLR0912, C901 if include_outputs_from is None: include_outputs_from = set() - if not resume_state: + if not pipeline_snapshot: # normalize `data` data = self._prepare_component_input_data(data) @@ -283,9 +281,9 @@ def run( # noqa: PLR0915, PLR0912, C901 resume_agent_in_pipeline = False else: - # Handle resume state - component_visits, data, resume_agent_in_pipeline, ordered_component_names = self._handle_resume_state( - resume_state + # Handle resuming the pipeline from a snapshot + component_visits, data, resume_agent_in_pipeline, ordered_component_names = self._handle_resume_pipeline( + pipeline_snapshot ) cached_topological_sort = None @@ -349,7 +347,9 @@ def run( # noqa: PLR0915, PLR0912, C901 component_name, component_visits[component_name] ) - is_resume = bool(resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name) + is_resume = bool( + pipeline_snapshot and pipeline_snapshot["pipeline_breakpoint"]["component"] == component_name + ) component_inputs = self._consume_component_inputs( component_name=component_name, component=component, inputs=inputs, is_resume=is_resume ) @@ -359,52 +359,46 @@ def run( # noqa: PLR0915, PLR0912, C901 # initialization component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) - # Scenario 1: Resume state is provided to resume the pipeline at a specific component - # Deserialize the component_inputs if they are passed in resume state + # Scenario 1: Pipeline snapshot is provided to resume the pipeline at a specific component + # Deserialize the component_inputs if they are passed in the pipeline_snapshot. 
# this check will prevent other component_inputs generated at runtime from being deserialized - if resume_state and component_name in resume_state["pipeline_state"]["inputs"].keys(): + if pipeline_snapshot and component_name in pipeline_snapshot["pipeline_state"]["inputs"].keys(): for key, value in component_inputs.items(): component_inputs[key] = _deserialize_value_with_schema(value) - # Scenario 2: a breakpoint is provided to stop the pipeline at a specific component and visit count - breakpoint_triggered = False - if break_point is not None: - agent_breakpoint = False - - if isinstance(break_point, AgentBreakpoint): - component_instance = component["instance"] - # Use type checking by class name to avoid circular import - if component_instance.__class__.__name__ == "Agent": - component_inputs = handle_agent_break_point( - break_point, - component_name, - component_inputs, - inputs, - component_visits, - ordered_component_names, - data, - debug_path, - ) - agent_breakpoint = True - - if not agent_breakpoint and isinstance(break_point, Breakpoint): - breakpoint_triggered = check_regular_break_point(break_point, component_name, component_visits) + # Scenario 2: an AgentBreakpoint is provided to stop the pipeline at a specific component + if isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name: + component_inputs = _handle_agent_break_point( + break_point=break_point, + component_name=component_name, + component_inputs=component_inputs, + inputs=inputs, + component_visits=component_visits, + ordered_component_names=ordered_component_names, + data=data, + ) + # Scenario 3: a regular breakpoint is provided to stop the pipeline at a specific component and + # visit count + if isinstance(break_point, Breakpoint): + breakpoint_triggered = _check_regular_break_point( + break_point=break_point, component_name=component_name, component_visits=component_visits + ) if breakpoint_triggered: - trigger_break_point( - component_name, - component_inputs, - inputs, - component_visits, - debug_path, - data, - ordered_component_names, - pipeline_outputs, + _trigger_break_point( + component_name=component_name, + component_inputs=component_inputs, + inputs=inputs, + component_visits=component_visits, + debug_path=break_point.debug_path, + data=data, + ordered_component_names=ordered_component_names, + pipeline_outputs=pipeline_outputs, ) if resume_agent_in_pipeline: - # inject the resume_state into the component (the Agent) inputs - component_inputs["resume_state"] = resume_state + # inject the pipeline_snapshot into the component (the Agent) inputs + component_inputs["pipeline_snapshot"] = pipeline_snapshot component_inputs["break_point"] = None component_outputs = self._run_component( @@ -430,7 +424,7 @@ def run( # noqa: PLR0915, PLR0912, C901 if self._is_queue_stale(priority_queue): priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) - if break_point and not agent_breakpoint: + if isinstance(break_point, Breakpoint): logger.warning( "The given breakpoint {break_point} was never triggered. This is because:\n" "1. 
The provided component is not a part of the pipeline execution path.\n"
                "2. The component did not reach the visit count specified in the breakpoint",
                break_point=break_point,
            )
@@ -440,25 +434,26 @@ def run(  # noqa: PLR0915, PLR0912, C901
 
         return pipeline_outputs
 
-    def inject_resume_state_into_graph(self, resume_state):
+    def _inject_pipeline_snapshot_into_graph(
+        self, pipeline_snapshot: Dict[str, Any]
+    ) -> tuple[Dict[str, int], Dict[str, Any], Dict[str, Any], list]:
         """
-        Loads the resume state from a file and injects it into the pipeline graph.
-
+        Injects the pipeline snapshot into the current pipeline graph.
         """
-        # We previously check if the resume_state is None but this is needed to prevent a typing error
-        if not resume_state:
-            raise PipelineInvalidResumeStateError("Cannot inject resume state: resume_state is None")
+        # The caller already checks that pipeline_snapshot is not None; this guard is kept to satisfy type checking
+        if not pipeline_snapshot:
+            raise PipelineInvalidPipelineSnapshotError("Cannot inject pipeline_snapshot: pipeline_snapshot is None")
 
-        # check if the resume_state is valid for the current pipeline
-        _validate_components_against_pipeline(resume_state, self.graph)
+        # check if the pipeline_snapshot is valid for the current pipeline
+        _validate_components_against_pipeline(pipeline_snapshot, self.graph)
 
-        data = self._prepare_component_input_data(resume_state["pipeline_state"]["inputs"])
-        component_visits = resume_state["pipeline_state"]["component_visits"]
-        ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"]
+        data = self._prepare_component_input_data(pipeline_snapshot["pipeline_state"]["inputs"])
+        component_visits = pipeline_snapshot["pipeline_state"]["component_visits"]
+        ordered_component_names = pipeline_snapshot["pipeline_state"]["ordered_component_names"]
 
         logger.info(
             "Resuming pipeline from {component} with visit count {visits}",
-            component=resume_state["pipeline_breakpoint"]["component"],
-            visits=resume_state["pipeline_breakpoint"]["visits"],
+            component=pipeline_snapshot["pipeline_breakpoint"]["component"],
+            visits=pipeline_snapshot["pipeline_breakpoint"]["visits"],
         )
 
-        return component_visits, data, resume_state, ordered_component_names
+        return component_visits, data, pipeline_snapshot, ordered_component_names
diff --git a/haystack/dataclasses/breakpoints.py b/haystack/dataclasses/breakpoints.py
index cf4443c29b..263e5e1986 100644
--- a/haystack/dataclasses/breakpoints.py
+++ b/haystack/dataclasses/breakpoints.py
@@ -6,91 +6,65 @@
 from typing import Optional, Union
 
 
-@dataclass
+@dataclass(frozen=True)
 class Breakpoint:
     """
     A dataclass to hold a breakpoint for a component.
+
+    :param component_name: The name of the component where the breakpoint is set.
+    :param visit_count: The number of times the component must be visited before the breakpoint is triggered.
+    :param debug_path: Optional path to store a snapshot of the pipeline when the breakpoint is hit.
+        This is useful for debugging purposes, allowing you to inspect the state of the pipeline at the time of the
+        breakpoint and to resume execution from that point.
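+
+    Example (illustrative):
+
+        Breakpoint(component_name="chat_generator", visit_count=0, debug_path="debug_snapshots/")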
""" component_name: str visit_count: int = 0 + debug_path: Optional[str] = None - def __hash__(self): - return hash((self.component_name, self.visit_count)) - - def __eq__(self, other): - if not isinstance(other, Breakpoint): - return False - return self.component_name == other.component_name and self.visit_count == other.visit_count - - def __str__(self): - return f"Breakpoint(component_name={self.component_name}, visit_count={self.visit_count})" - def __repr__(self): - return self.__str__() - - -@dataclass +@dataclass(frozen=True) class ToolBreakpoint(Breakpoint): """ - A dataclass to hold a breakpoint that can be used to debug a Tool. + A dataclass representing a breakpoint specific to tools used within an Agent component. + + Inherits from Breakpoint and adds the ability to target individual tools. If `tool_name` is None, + the breakpoint applies to all tools within the Agent component. - If tool_name is None, it means that the breakpoint is for every tool in the component. - Otherwise, it means that the breakpoint is for the tool with the given name. + :param tool_name: The name of the tool to target within the Agent component. If None, applies to all tools. """ tool_name: Optional[str] = None - def __hash__(self): - return hash((self.component_name, self.visit_count, self.tool_name)) - - def __eq__(self, other): - if not isinstance(other, ToolBreakpoint): - return False - return super().__eq__(other) and self.tool_name == other.tool_name - def __str__(self): - if self.tool_name: - return ( - f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}, " - f"tool_name={self.tool_name})" - ) - else: - return ( - f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}, " - f"tool_name=ALL_TOOLS)" - ) - - def __repr__(self): - return self.__str__() + tool_str = f", tool_name={self.tool_name}" if self.tool_name else ", tool_name=ALL_TOOLS" + return f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}{tool_str})" @dataclass class AgentBreakpoint: """ - A dataclass to hold a breakpoint that can be used to debug an Agent. - """ + A dataclass representing a breakpoint tied to an Agent’s execution. - break_point: Union[Breakpoint, ToolBreakpoint] - agent_name: str = "" + This allows for debugging either a specific component (e.g., the chat generator) or a tool used by the agent. + It enforces constraints on which component names are valid for each breakpoint type. - def __init__(self, agent_name: str, break_point: Union[Breakpoint, ToolBreakpoint]): - if not isinstance(break_point, ToolBreakpoint) and break_point.component_name != "chat_generator": - raise ValueError( - "The break_point must be a Breakpoint that has the component_name " - "'chat_generator' or be a ToolBreakpoint." - ) + :param agent_name: The name of the agent component in a pipeline where the breakpoint is set. + :param break_point: An instance of Breakpoint or ToolBreakpoint indicating where to break execution. - if not break_point: - raise ValueError("A Breakpoint must be provided.") + :raises ValueError: If the component_name is invalid for the given breakpoint type: + - Breakpoint must have component_name='chat_generator'. + - ToolBreakpoint must have component_name='tool_invoker'. 
+ """ - self.agent_name = agent_name + agent_name: str + break_point: Union[Breakpoint, ToolBreakpoint] + def __post_init__(self): if ( - isinstance(break_point, ToolBreakpoint) - or isinstance(break_point, Breakpoint) - and not isinstance(break_point, ToolBreakpoint) - ): - self.break_point = break_point - else: - raise ValueError("The breakpoint must be either Breakpoint or ToolBreakpoint.") + isinstance(self.break_point, Breakpoint) and not isinstance(self.break_point, ToolBreakpoint) + ) and self.break_point.component_name != "chat_generator": + raise ValueError("If the break_point is a Breakpoint, it must have the component_name 'chat_generator'.") + + if isinstance(self.break_point, ToolBreakpoint) and self.break_point.component_name != "tool_invoker": + raise ValueError("If the break_point is a ToolBreakpoint, it must have the component_name 'tool_invoker'.") diff --git a/test/components/agents/test_agent_breakpoints_inside_pipeline.py b/test/components/agents/test_agent_breakpoints_inside_pipeline.py index d5a019b538..817e1fdece 100644 --- a/test/components/agents/test_agent_breakpoints_inside_pipeline.py +++ b/test/components/agents/test_agent_breakpoints_inside_pipeline.py @@ -3,17 +3,20 @@ # SPDX-License-Identifier: Apache-2.0 import os +import re import tempfile from pathlib import Path from typing import Dict, List, Optional +import pytest + from haystack import component from haystack.components.agents import Agent from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator -from haystack.core.errors import BreakpointException, PipelineRuntimeError +from haystack.core.errors import BreakpointException from haystack.core.pipeline import Pipeline -from haystack.core.pipeline.breakpoint import load_state +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint from haystack.document_stores.in_memory import InMemoryDocumentStore @@ -77,8 +80,6 @@ def run(self, sources: List[ByteStream]) -> Dict[str, List[Document]]: # Simple text extraction - remove HTML tags and extract meaningful content # This is a simplified version that extracts the main content - import re - # Remove HTML tags text_content = re.sub(r"<[^>]+>", " ", html_content) # Remove extra whitespace @@ -101,7 +102,9 @@ def add_database_tool(name: str, surname: str, job_title: Optional[str], other: ) -def create_pipeline(): +@pytest.fixture +def pipeline_with_agent(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test_key") generator = OpenAIChatGenerator() call_count = 0 @@ -198,8 +201,7 @@ def mock_run(messages, tools=None, **kwargs): return extraction_agent -def run_pipeline_without_any_breakpoints(): - pipeline_with_agent = create_pipeline() +def run_pipeline_without_any_breakpoints(pipeline_with_agent): agent_output = pipeline_with_agent.run(data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}) # pipeline completed @@ -215,84 +217,70 @@ def run_pipeline_without_any_breakpoints(): assert "Chief Technology Officer" in final_message -def test_chat_generator_breakpoint_in_pipeline_agent(): - pipeline_with_agent = create_pipeline() - agent_generator_breakpoint = Breakpoint("chat_generator", 0) - agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") - +def 
test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: + agent_generator_breakpoint = Breakpoint("chat_generator", 0, debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( - data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, - break_point=agent_breakpoint, - debug_path=debug_path, + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint ) assert False, "Expected exception was not raised" except BreakpointException as e: # this is the exception from the Agent assert e.component == "chat_generator" - assert e.state is not None - assert "messages" in e.state + assert e.pipeline_snapshot is not None + assert "messages" in e.pipeline_snapshot assert e.results is not None - # verify that debug/state file was created - chat_generator_state_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) - assert len(chat_generator_state_files) > 0, f"No chat_generator state files found in {debug_path}" - + # verify that snapshot file was created + chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) + assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}" -def test_tool_breakpoint_in_pipeline_agent(): - pipeline_with_agent = create_pipeline() - agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, "add_database_tool") - agent_breakpoints = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") +def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: + agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, tool_name="add_database_tool", debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( - data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, - break_point=agent_breakpoints, - debug_path=debug_path, + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint ) assert False, "Expected exception was not raised" except BreakpointException as e: # this is the exception from the Agent assert e.component == "tool_invoker" - assert e.state is not None - assert "messages" in e.state + assert e.pipeline_snapshot is not None + assert "messages" in e.pipeline_snapshot assert e.results is not None - # verify that debug/state file was created - tool_invoker_state_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) - assert len(tool_invoker_state_files) > 0, f"No tool_invoker state files found in {debug_path}" + # verify that snapshot file was created + tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) + assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}" -def test_agent_breakpoint_chat_generator_and_resume_pipeline(): - pipeline_with_agent = create_pipeline() - agent_generator_breakpoint = Breakpoint("chat_generator", 0) - agent_breakpoints = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") - +def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: + agent_generator_breakpoint = 
Breakpoint("chat_generator", 0, debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( - data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, - break_point=agent_breakpoints, - debug_path=debug_path, + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint ) assert False, "Expected PipelineBreakpointException was not raised" except BreakpointException as e: assert e.component == "chat_generator" - assert e.state is not None - assert "messages" in e.state + assert e.pipeline_snapshot is not None + assert "messages" in e.pipeline_snapshot assert e.results is not None - # verify that the state file was created - chat_generator_state_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) - assert len(chat_generator_state_files) > 0, f"No chat_generator state files found in {debug_path}" - - # resume the pipeline from the saved state - latest_state_file = max(chat_generator_state_files, key=os.path.getctime) - resume_state = load_state(latest_state_file) + # verify that the snapshot file was created + chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json")) + assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}" - result = pipeline_with_agent.run(data={}, resume_state=resume_state) + # resume the pipeline from the saved snapshot + latest_snapshot_file = max(chat_generator_snapshot_files, key=os.path.getctime) + result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file)) # pipeline completed successfully after resuming assert "database_agent" in result @@ -316,35 +304,29 @@ def test_agent_breakpoint_chat_generator_and_resume_pipeline(): assert any("Milos Rusic" in name for name in person_names) -def test_agent_breakpoint_tool_and_resume_pipeline(): - pipeline_with_agent = create_pipeline() - agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, "add_database_tool") - agent_breakpoints = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") - +def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: + agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, tool_name="add_database_tool", debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( - data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, - break_point=agent_breakpoints, - debug_path=debug_path, + data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint ) assert False, "Expected PipelineBreakpointException was not raised" except BreakpointException as e: assert e.component == "tool_invoker" - assert e.state is not None - assert "messages" in e.state + assert e.pipeline_snapshot is not None + assert "messages" in e.pipeline_snapshot assert e.results is not None - # verify that the state file was created - tool_invoker_state_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) - assert len(tool_invoker_state_files) > 0, f"No tool_invoker state files found in {debug_path}" - - # resume the pipeline from the saved state - latest_state_file = max(tool_invoker_state_files, key=os.path.getctime) - resume_state = load_state(latest_state_file) + # verify 
that the snapshot file was created + tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) + assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}" - result = pipeline_with_agent.run(data={}, resume_state=resume_state) + # resume the pipeline from the saved snapshot + latest_snapshot_file = max(tool_invoker_snapshot_files, key=os.path.getctime) + result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file)) # pipeline completed successfully after resuming assert "database_agent" in result diff --git a/test/components/agents/test_agent_breakpoints_isolation_async.py b/test/components/agents/test_agent_breakpoints_isolation_async.py index 4f86875605..297b7df711 100644 --- a/test/components/agents/test_agent_breakpoints_isolation_async.py +++ b/test/components/agents/test_agent_breakpoints_isolation_async.py @@ -4,28 +4,19 @@ import os from pathlib import Path -from typing import Optional from unittest.mock import AsyncMock import pytest from haystack.components.agents import Agent from haystack.core.errors import BreakpointException -from haystack.core.pipeline.breakpoint import load_state +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot from haystack.dataclasses import ChatMessage, ToolCall from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint from haystack.tools import Tool from test.components.agents.test_agent import MockChatGeneratorWithRunAsync, weather_function -agent_name = "isolated_agent" - - -def create_chat_generator_breakpoint(visit_count: int = 0) -> Breakpoint: - return Breakpoint(component_name="chat_generator", visit_count=visit_count) - - -def create_tool_breakpoint(tool_name: Optional[str] = None, visit_count: int = 0) -> ToolBreakpoint: - return ToolBreakpoint(component_name="tool_invoker", visit_count=visit_count, tool_name=tool_name) +AGENT_NAME = "isolated_agent" @pytest.fixture @@ -70,7 +61,7 @@ def agent(mock_chat_generator, weather_tool): @pytest.fixture def debug_path(tmp_path): - return str(tmp_path / "debug_states") + return str(tmp_path / "debug_snapshots") @pytest.fixture @@ -90,49 +81,47 @@ def mock_agent_with_tool_calls(monkeypatch, weather_tool): @pytest.mark.asyncio async def test_run_async_with_chat_generator_breakpoint(agent): messages = [ChatMessage.from_user("What's the weather in Berlin?")] - chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0) agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test") with pytest.raises(BreakpointException) as exc_info: - await agent.run_async(messages=messages, break_point=agent_breakpoint, agent_name=agent_name) + await agent.run_async(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) assert exc_info.value.component == "chat_generator" - assert "messages" in exc_info.value.state + assert "messages" in exc_info.value.pipeline_snapshot @pytest.mark.asyncio async def test_run_async_with_tool_invoker_breakpoint(mock_agent_with_tool_calls): messages = [ChatMessage.from_user("What's the weather in Berlin?")] - tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") with pytest.raises(BreakpointException) as 
exc_info: await mock_agent_with_tool_calls.run_async( - messages=messages, break_point=agent_breakpoint, agent_name=agent_name + messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME ) assert exc_info.value.component == "tool_invoker" - assert "messages" in exc_info.value.state + assert "messages" in exc_info.value.pipeline_snapshot @pytest.mark.asyncio async def test_resume_from_chat_generator_async(agent, debug_path): messages = [ChatMessage.from_user("What's the weather in Berlin?")] - chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) - agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=agent_name) + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=AGENT_NAME) try: - await agent.run_async( - messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name - ) + await agent.run_async(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) except BreakpointException: pass - state_files = list(Path(debug_path).glob(agent_name + "_chat_generator_*.json")) + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_chat_generator_*.json")) - assert len(state_files) > 0 - latest_state_file = str(max(state_files, key=os.path.getctime)) + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) - resume_state = load_state(latest_state_file) result = await agent.run_async( - messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state + messages=[ChatMessage.from_user("Continue from where we left off.")], + pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), ) assert "messages" in result @@ -143,25 +132,26 @@ async def test_resume_from_chat_generator_async(agent, debug_path): @pytest.mark.asyncio async def test_resume_from_tool_invoker_async(mock_agent_with_tool_calls, debug_path): messages = [ChatMessage.from_user("What's the weather in Berlin?")] - tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) - agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=agent_name) + tool_bp = ToolBreakpoint( + component_name="tool_invoker", visit_count=0, tool_name="weather_tool", debug_path=debug_path + ) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=AGENT_NAME) try: await mock_agent_with_tool_calls.run_async( - messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name + messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME ) except BreakpointException: pass - state_files = list(Path(debug_path).glob(agent_name + "_tool_invoker_*.json")) - - assert len(state_files) > 0 - latest_state_file = str(max(state_files, key=os.path.getctime)) + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_tool_invoker_*.json")) - resume_state = load_state(latest_state_file) + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) result = await mock_agent_with_tool_calls.run_async( - messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state + messages=[ChatMessage.from_user("Continue from where we left off.")], + pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), ) assert "messages" in result @@ -170,13 +160,13 @@ async def test_resume_from_tool_invoker_async(mock_agent_with_tool_calls, 
debug_ @pytest.mark.asyncio -async def test_invalid_combination_breakpoint_and_resume_state_async(mock_agent_with_tool_calls): +async def test_invalid_combination_breakpoint_and_pipeline_snapshot_async(mock_agent_with_tool_calls): messages = [ChatMessage.from_user("What's the weather in Berlin?")] - tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") - with pytest.raises(ValueError, match="agent_breakpoint and resume_state cannot be provided at the same time"): + with pytest.raises(ValueError, match="agent_breakpoint and pipeline_snapshot cannot be provided at the same time"): await mock_agent_with_tool_calls.run_async( - messages=messages, break_point=agent_breakpoint, resume_state={"some": "state"} + messages=messages, break_point=agent_breakpoint, pipeline_snapshot={"some": "snapshot"} ) @@ -189,7 +179,7 @@ async def test_breakpoint_with_invalid_component_async(mock_agent_with_tool_call @pytest.mark.asyncio async def test_breakpoint_with_invalid_tool_name_async(mock_agent_with_tool_calls): - tool_breakpoint = create_tool_breakpoint(tool_name="invalid_tool", visit_count=0) + tool_breakpoint = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="invalid_tool") with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"): agent_breakpoint = AgentBreakpoint(break_point=tool_breakpoint, agent_name="test") await mock_agent_with_tool_calls.run_async( diff --git a/test/components/agents/test_agent_breakpoints_isolation_sync.py b/test/components/agents/test_agent_breakpoints_isolation_sync.py index 23c3b5b277..dd649479ea 100644 --- a/test/components/agents/test_agent_breakpoints_isolation_sync.py +++ b/test/components/agents/test_agent_breakpoints_isolation_sync.py @@ -8,59 +8,57 @@ import pytest from haystack.core.errors import BreakpointException -from haystack.core.pipeline.breakpoint import load_state +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot from haystack.dataclasses import ChatMessage -from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint from test.components.agents.test_agent_breakpoints_utils import ( agent_sync, - create_chat_generator_breakpoint, - create_tool_breakpoint, mock_agent_with_tool_calls_sync, weather_tool, ) -agent_name = "isolated_agent" +AGENT_NAME = "isolated_agent" def test_run_with_chat_generator_breakpoint(agent_sync): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] - chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0) agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test_agent") with pytest.raises(BreakpointException) as exc_info: agent_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") assert exc_info.value.component == "chat_generator" - assert "messages" in exc_info.value.state + assert "messages" in exc_info.value.pipeline_snapshot def test_run_with_tool_invoker_breakpoint(mock_agent_with_tool_calls_sync): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] - tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + tool_bp = 
ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") with pytest.raises(BreakpointException) as exc_info: mock_agent_with_tool_calls_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") assert exc_info.value.component == "tool_invoker" - assert "messages" in exc_info.value.state + assert "messages" in exc_info.value.pipeline_snapshot def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] - chat_generator_bp = create_chat_generator_breakpoint(visit_count=0) - agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=agent_name) - debug_path = str(tmp_path / "debug_states") + debug_path = str(tmp_path / "debug_snapshots") + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=AGENT_NAME) try: - agent_sync.run(messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name) + agent_sync.run(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) except BreakpointException: pass - state_files = list(Path(debug_path).glob(agent_name + "_chat_generator_*.json")) - assert len(state_files) > 0 - latest_state_file = str(max(state_files, key=os.path.getctime)) + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_chat_generator_*.json")) + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) - resume_state = load_state(latest_state_file) result = agent_sync.run( - messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state + messages=[ChatMessage.from_user("Continue from where we left off.")], + pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), ) assert "messages" in result @@ -70,25 +68,22 @@ def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 def test_resume_from_tool_invoker(mock_agent_with_tool_calls_sync, tmp_path): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] - tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) - agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=agent_name) - debug_path = str(tmp_path / "debug_states") + debug_path = str(tmp_path / "debug_snapshots") + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name=None, debug_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=AGENT_NAME) try: - mock_agent_with_tool_calls_sync.run( - messages=messages, break_point=agent_breakpoint, debug_path=debug_path, agent_name=agent_name - ) + mock_agent_with_tool_calls_sync.run(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) except BreakpointException: pass - state_files = list(Path(debug_path).glob(agent_name + "_tool_invoker_*.json")) - assert len(state_files) > 0 - latest_state_file = str(max(state_files, key=os.path.getctime)) - - resume_state = load_state(latest_state_file) + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_tool_invoker_*.json")) + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) result = mock_agent_with_tool_calls_sync.run( - messages=[ChatMessage.from_user("Continue from where we left off.")], resume_state=resume_state 
+ messages=[ChatMessage.from_user("Continue from where we left off.")], + pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), ) assert "messages" in result @@ -96,13 +91,13 @@ def test_resume_from_tool_invoker(mock_agent_with_tool_calls_sync, tmp_path): # assert len(result["messages"]) > 0 -def test_invalid_combination_breakpoint_and_resume_state(mock_agent_with_tool_calls_sync): # noqa: F811 +def test_invalid_combination_breakpoint_and_pipeline_snapshot(mock_agent_with_tool_calls_sync): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] - tool_bp = create_tool_breakpoint(tool_name="weather_tool", visit_count=0) + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") - with pytest.raises(ValueError, match="agent_breakpoint and resume_state cannot be provided at the same time"): + with pytest.raises(ValueError, match="agent_breakpoint and pipeline_snapshot cannot be provided at the same time"): mock_agent_with_tool_calls_sync.run( - messages=messages, break_point=agent_breakpoint, resume_state={"some": "state"} + messages=messages, break_point=agent_breakpoint, pipeline_snapshot={"some": "snapshot"} ) @@ -113,7 +108,7 @@ def test_breakpoint_with_invalid_component(mock_agent_with_tool_calls_sync): # def test_breakpoint_with_invalid_tool_name(mock_agent_with_tool_calls_sync): # noqa: F811 - tool_breakpoint = create_tool_breakpoint(tool_name="invalid_tool", visit_count=0) + tool_breakpoint = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="invalid_tool") with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"): agent_breakpoints = AgentBreakpoint(break_point=tool_breakpoint, agent_name="test_agent") mock_agent_with_tool_calls_sync.run( diff --git a/test/components/agents/test_agent_breakpoints_utils.py b/test/components/agents/test_agent_breakpoints_utils.py index a7beeaecf8..3edd23100f 100644 --- a/test/components/agents/test_agent_breakpoints_utils.py +++ b/test/components/agents/test_agent_breakpoints_utils.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -10,7 +9,6 @@ from haystack.components.agents import Agent from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ToolCall -from haystack.dataclasses.breakpoints import Breakpoint, ToolBreakpoint from haystack.tools import Tool from test.components.agents.test_agent import ( MockChatGeneratorWithoutRunAsync, @@ -19,14 +17,6 @@ ) -def create_chat_generator_breakpoint(visit_count: int = 0) -> Breakpoint: - return Breakpoint(component_name="chat_generator", visit_count=visit_count) - - -def create_tool_breakpoint(tool_name: Optional[str] = None, visit_count: int = 0) -> ToolBreakpoint: - return ToolBreakpoint(component_name="tool_invoker", visit_count=visit_count, tool_name=tool_name) - - # Common fixtures @pytest.fixture def weather_tool(): @@ -40,7 +30,7 @@ def weather_tool(): @pytest.fixture def debug_path(tmp_path): - return str(tmp_path / "debug_states") + return str(tmp_path / "debug_snapshots") @pytest.fixture diff --git a/test/conftest.py b/test/conftest.py index 98477b957e..002c026d12 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,7 +11,7 @@ import pytest from haystack import component, tracing -from haystack.core.pipeline.breakpoint 
import load_state +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot from haystack.testing.test_utils import set_all_seeds from test.tracing.utils import SpyingTracer @@ -83,20 +83,19 @@ def spying_tracer() -> Generator[SpyingTracer, None, None]: tracing.disable_tracing() -def load_and_resume_pipeline_state(pipeline, output_directory: Path, component: str, data: Dict = None) -> Dict: +def load_and_resume_pipeline_snapshot(pipeline, output_directory: Path, component_name: str, data: Dict = None) -> Dict: """ - Utility function to load and resume pipeline state from a breakpoint file. + Utility function to load and resume a pipeline snapshot from a breakpoint file. - Args: - pipeline: The pipeline instance to resume - output_directory: Directory containing the breakpoint files - component: Component name to look for in breakpoint files - data: Data to pass to the pipeline run (defaults to empty dict) + :param pipeline: The pipeline instance to resume + :param output_directory: Directory containing the breakpoint files + :param component_name: Component name to look for in breakpoint files + :param data: Data to pass to the pipeline run (defaults to empty dict) - Returns: + :returns: Dict containing the pipeline run results - Raises: + :raises: ValueError: If no breakpoint file is found for the given component """ data = data or {} @@ -105,10 +104,10 @@ def load_and_resume_pipeline_state(pipeline, output_directory: Path, component: for full_path in all_files: f_name = Path(full_path).name - if str(f_name).startswith(component): - resume_state = load_state(full_path) - return pipeline.run(data=data, resume_state=resume_state) + if str(f_name).startswith(component_name): + pipeline_snapshot = load_pipeline_snapshot(full_path) + return pipeline.run(data=data, pipeline_snapshot=pipeline_snapshot) if not file_found: - msg = f"No files found for {component} in {output_directory}." + msg = f"No files found for {component_name} in {output_directory}."
raise ValueError(msg) diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py index 8f6c671171..505292668c 100644 --- a/test/core/pipeline/test_breakpoint.py +++ b/test/core/pipeline/test_breakpoint.py @@ -8,9 +8,8 @@ from haystack.core.pipeline.breakpoint import ( _transform_json_structure, - _validate_break_point, - _validate_resume_state, - load_state, + _validate_pipeline_snapshot, + load_pipeline_snapshot, ) @@ -38,17 +37,17 @@ def test_transform_json_structure_handles_nested_structures(): assert result == {"key1": "value1", "key2": {"nested": "value2", "direct": "value3"}, "key3": ["value4", "value5"]} -def test_validate_resume_state_validates_required_keys(): - state = { +def test_validate_pipeline_snapshot_validates_required_keys(): + pipeline_snapshot = { "input_data": {}, "pipeline_breakpoint": {"component": "comp1", "visits": 0}, # Missing pipeline_state } - with pytest.raises(ValueError, match="Invalid state file: missing required keys"): - _validate_resume_state(state) + with pytest.raises(ValueError, match="Invalid pipeline_snapshot: missing required keys"): + _validate_pipeline_snapshot(pipeline_snapshot) - state = { + pipeline_snapshot = { "input_data": {}, "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { @@ -59,11 +58,11 @@ def test_validate_resume_state_validates_required_keys(): } with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"): - _validate_resume_state(state) + _validate_pipeline_snapshot(pipeline_snapshot) -def test_validate_resume_state_validates_component_consistency(): - state = { +def test_validate_pipeline_snapshot_validates_component_consistency(): + pipeline_snapshot = { "input_data": {}, "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { @@ -74,11 +73,11 @@ def test_validate_resume_state_validates_component_consistency(): } with pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): - _validate_resume_state(state) + _validate_pipeline_snapshot(pipeline_snapshot) -def test_validate_resume_state_validates_valid_state(): - state = { +def test_validate_pipeline_snapshot_validates_valid_snapshot(): + pipeline_snapshot = { "input_data": {}, "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { @@ -88,11 +87,11 @@ def test_validate_resume_state_validates_valid_state(): }, } - _validate_resume_state(state) # should not raise any exception + _validate_pipeline_snapshot(pipeline_snapshot) # should not raise any exception -def test_load_state_loads_valid_state(tmp_path): - state = { +def test_load_pipeline_snapshot_loads_valid_snapshot(tmp_path): + pipeline_snapshot = { "input_data": {}, "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { @@ -101,16 +100,16 @@ def test_load_state_loads_valid_state(tmp_path): "ordered_component_names": ["comp1", "comp2"], }, } - state_file = tmp_path / "state.json" - with open(state_file, "w") as f: - json.dump(state, f) + pipeline_snapshot_file = tmp_path / "state.json" + with open(pipeline_snapshot_file, "w") as f: + json.dump(pipeline_snapshot, f) - loaded_state = load_state(state_file) - assert loaded_state == state + loaded_snapshot = load_pipeline_snapshot(pipeline_snapshot_file) + assert loaded_snapshot == pipeline_snapshot def test_load_state_handles_invalid_state(tmp_path): - state = { + pipeline_snapshot = { "input_data": {}, "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { @@ -120,9 
+119,9 @@ def test_load_state_handles_invalid_state(tmp_path): }, } - state_file = tmp_path / "invalid_state.json" - with open(state_file, "w") as f: - json.dump(state, f) + pipeline_snapshot_file = tmp_path / "invalid_pipeline_snapshot.json" + with open(pipeline_snapshot_file, "w") as f: + json.dump(pipeline_snapshot, f) - with pytest.raises(ValueError, match="Invalid pipeline state from"): - load_state(state_file) + with pytest.raises(ValueError, match="Invalid pipeline snapshot from"): + load_pipeline_snapshot(pipeline_snapshot_file) diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py index bec638846c..9752ac9c01 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -14,7 +15,7 @@ from haystack.dataclasses import ChatMessage from haystack.dataclasses.breakpoints import Breakpoint from haystack.utils.auth import Secret -from test.conftest import load_and_resume_pipeline_state +from test.conftest import load_and_resume_pipeline_snapshot class TestPipelineBreakpoints: @@ -86,18 +87,12 @@ def answer_join_pipeline(self, mock_openai_chat_generator): return pipeline @pytest.fixture(scope="session") - def output_directory(self, tmp_path_factory): + def output_directory(self, tmp_path_factory) -> Path: return tmp_path_factory.mktemp("output_files") - components = [ - Breakpoint("gpt-4o", 0), - Breakpoint("gpt-3", 0), - Breakpoint("answer_builder_a", 0), - Breakpoint("answer_builder_b", 0), - Breakpoint("answer_joiner", 0), - ] + BREAKPOINT_COMPONENTS = ["gpt-4o", "gpt-3", "answer_builder_a", "answer_builder_b", "answer_joiner"] - @pytest.mark.parametrize("component", components) + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) @pytest.mark.integration def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_directory, component): """ @@ -115,10 +110,18 @@ def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_d "answer_builder_b": {"query": query}, } + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + try: - _ = answer_join_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + _ = answer_join_pipeline.run(data, break_point=break_point) except BreakpointException: pass - result = load_and_resume_pipeline_state(answer_join_pipeline, output_directory, component.component_name, data) + result = load_and_resume_pipeline_snapshot( + pipeline=answer_join_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) assert result["answer_joiner"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py index 4fb41dbc87..75e526e57b 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from pathlib import Path from typing import List from unittest.mock import MagicMock, patch @@ -16,7 +17,7 @@ from haystack.dataclasses import ChatMessage from haystack.dataclasses.breakpoints import Breakpoint 
from haystack.utils.auth import Secret -from test.conftest import load_and_resume_pipeline_state +from test.conftest import load_and_resume_pipeline_snapshot class TestPipelineBreakpoints: @@ -91,17 +92,12 @@ def branch_joiner_pipeline(self, mock_openai_chat_generator): return pipe @pytest.fixture(scope="session") - def output_directory(self, tmp_path_factory): + def output_directory(self, tmp_path_factory) -> Path: return tmp_path_factory.mktemp("output_files") - components = [ - Breakpoint("joiner", 0), - Breakpoint("fc_llm", 0), - Breakpoint("validator", 0), - Breakpoint("adapter", 0), - ] + BREAKPOINT_COMPONENTS = ["joiner", "fc_llm", "validator", "adapter"] - @pytest.mark.parametrize("component", components) + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) @pytest.mark.integration def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output_directory, component): data = { @@ -109,12 +105,18 @@ def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output "adapter": {"chat_message": [ChatMessage.from_user("Create JSON from Peter Parker")]}, } + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + try: - _ = branch_joiner_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + _ = branch_joiner_pipeline.run(data, break_point=break_point) except BreakpointException: pass - result = load_and_resume_pipeline_state( - branch_joiner_pipeline, output_directory, component.component_name, data + result = load_and_resume_pipeline_snapshot( + pipeline=branch_joiner_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, ) assert result["validator"], "The result should be valid according to the schema." diff --git a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py index 5c12c40d64..7281585fdb 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from pathlib import Path from typing import List from unittest.mock import MagicMock, patch @@ -15,7 +16,7 @@ from haystack.dataclasses import ChatMessage from haystack.dataclasses.breakpoints import Breakpoint from haystack.utils.auth import Secret -from test.conftest import load_and_resume_pipeline_state +from test.conftest import load_and_resume_pipeline_snapshot class TestPipelineBreakpoints: @@ -109,18 +110,12 @@ def list_joiner_pipeline(self, mock_openai_chat_generator): return pipe @pytest.fixture(scope="session") - def output_directory(self, tmp_path_factory): + def output_directory(self, tmp_path_factory) -> Path: return tmp_path_factory.mktemp("output_files") - components = [ - Breakpoint("prompt_builder", 0), - Breakpoint("llm", 0), - Breakpoint("feedback_prompt_builder", 0), - Breakpoint("feedback_llm", 0), - Breakpoint("list_joiner", 0), - ] + BREAKPOINT_COMPONENTS = ["prompt_builder", "llm", "feedback_prompt_builder", "feedback_llm", "list_joiner"] - @pytest.mark.parametrize("component", components) + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) @pytest.mark.integration def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, component): query = "What is nuclear physics?" 
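All of these breakpoint suites now follow the same pattern: construct a Breakpoint that names a component and a debug_path, run the pipeline until BreakpointException fires, then resume from the newest snapshot file. A minimal sketch of that flow outside the test harness, assuming a placeholder my_pipeline with an "llm" component and an input dict named data:

import os
from pathlib import Path

from haystack.core.errors import BreakpointException
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses.breakpoints import Breakpoint

# Stop before the first visit of the "llm" component; a JSON snapshot is written to debug_path.
break_point = Breakpoint(component_name="llm", visit_count=0, debug_path="snapshots")
try:
    my_pipeline.run(data=data, break_point=break_point)  # my_pipeline and data are placeholders
except BreakpointException:
    pass

# Resume from the most recent snapshot produced for that component.
snapshot_files = list(Path("snapshots").glob("llm_*.json"))
pipeline_snapshot = load_pipeline_snapshot(max(snapshot_files, key=os.path.getctime))
result = my_pipeline.run(data={}, pipeline_snapshot=pipeline_snapshot)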
@@ -129,10 +124,18 @@ def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, comp "feedback_prompt_builder": {"template_variables": {"query": query}}, } + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + try: - _ = list_joiner_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + _ = list_joiner_pipeline.run(data, break_point=break_point) except BreakpointException: pass - result = load_and_resume_pipeline_state(list_joiner_pipeline, output_directory, component.component_name, data) + result = load_and_resume_pipeline_snapshot( + pipeline=list_joiner_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) assert result["list_joiner"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py index 2b63f4da9b..fd4365d8b8 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_loops.py +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -16,7 +16,7 @@ from haystack.components.builders import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.core.errors import BreakpointException -from haystack.core.pipeline.breakpoint import load_state +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot from haystack.core.pipeline.pipeline import Pipeline from haystack.dataclasses import ChatMessage from haystack.dataclasses.breakpoints import Breakpoint @@ -198,9 +198,9 @@ def test_data(self): return {"schema": json_schema, "passage": passage} - components = [Breakpoint("prompt_builder", 0), Breakpoint("llm", 0), Breakpoint("output_validator", 0)] + BREAKPOINT_COMPONENTS = ["prompt_builder", "llm", "output_validator"] - @pytest.mark.parametrize("component", components) + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) @pytest.mark.integration def test_pipeline_breakpoints_validation_loop( self, validation_loop_pipeline, output_directory, test_data, component @@ -210,8 +210,11 @@ def test_pipeline_breakpoints_validation_loop( """ data = {"prompt_builder": {"passage": test_data["passage"], "schema": test_data["schema"]}} + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + try: - _ = validation_loop_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + _ = validation_loop_pipeline.run(data, break_point=break_point) except BreakpointException: pass @@ -219,10 +222,9 @@ def test_pipeline_breakpoints_validation_loop( file_found = False for full_path in all_files: f_name = Path(full_path).name - if str(f_name).startswith(component.component_name): + if str(f_name).startswith(break_point.component_name): file_found = True - resume_state = load_state(full_path) - result = validation_loop_pipeline.run(data={}, resume_state=resume_state) + result = validation_loop_pipeline.run(data={}, pipeline_snapshot=load_pipeline_snapshot(full_path)) # Verify the result contains valid output if "output_validator" in result and "valid_replies" in result["output_validator"]: valid_reply = result["output_validator"]["valid_replies"][0].text diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py index 
affb836dcc..6c9f8a95fe 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -21,7 +22,7 @@ from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.document_stores.types import DuplicatePolicy from haystack.utils.auth import Secret -from test.conftest import load_and_resume_pipeline_state +from test.conftest import load_and_resume_pipeline_snapshot class TestPipelineBreakpoints: @@ -251,21 +252,21 @@ def hybrid_rag_pipeline( return pipeline @pytest.fixture(scope="session") - def output_directory(self, tmp_path_factory): + def output_directory(self, tmp_path_factory) -> Path: return tmp_path_factory.mktemp("output_files") - components = [ - Breakpoint("bm25_retriever", 0), - Breakpoint("query_embedder", 0), - Breakpoint("embedding_retriever", 0), - Breakpoint("doc_joiner", 0), - Breakpoint("ranker", 0), - Breakpoint("prompt_builder", 0), - Breakpoint("llm", 0), - Breakpoint("answer_builder", 0), + BREAKPOINT_COMPONENTS = [ + "bm25_retriever", + "query_embedder", + "embedding_retriever", + "doc_joiner", + "ranker", + "prompt_builder", + "llm", + "answer_builder", ] - @pytest.mark.parametrize("component", components) + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) @pytest.mark.integration def test_pipeline_breakpoints_hybrid_rag( self, hybrid_rag_pipeline, document_store, output_directory, component, mock_openai_completion @@ -283,10 +284,18 @@ def test_pipeline_breakpoints_hybrid_rag( "answer_builder": {"query": question}, } + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + try: - _ = hybrid_rag_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + _ = hybrid_rag_pipeline.run(data, break_point=break_point) except BreakpointException: pass - result = load_and_resume_pipeline_state(hybrid_rag_pipeline, output_directory, component.component_name, data) + result = load_and_resume_pipeline_snapshot( + pipeline=hybrid_rag_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) assert result["answer_builder"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py index 8c01b5973b..4b791535c6 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py @@ -11,7 +11,7 @@ from haystack.core.pipeline.pipeline import Pipeline from haystack.dataclasses import ChatMessage from haystack.dataclasses.breakpoints import Breakpoint -from test.conftest import load_and_resume_pipeline_state +from test.conftest import load_and_resume_pipeline_snapshot class TestPipelineBreakpoints: @@ -39,27 +39,27 @@ def string_joiner_pipeline(self): def output_directory(self, tmp_path_factory): return tmp_path_factory.mktemp("output_files") - components = [ - Breakpoint("prompt_builder_1", 0), - Breakpoint("prompt_builder_2", 0), - Breakpoint("adapter_1", 0), - Breakpoint("adapter_2", 0), - Breakpoint("string_joiner", 0), - ] + BREAKPOINT_COMPONENTS = ["prompt_builder_1", "prompt_builder_2", "adapter_1", "adapter_2", "string_joiner"] - @pytest.mark.parametrize("component", 
components) + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) @pytest.mark.integration def test_string_joiner_pipeline(self, string_joiner_pipeline, output_directory, component): string_1 = "What's Natural Language Processing?" string_2 = "What is life?" data = {"prompt_builder_1": {"query": string_1}, "prompt_builder_2": {"query": string_2}} + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + try: - _ = string_joiner_pipeline.run(data, break_point=component, debug_path=str(output_directory)) + _ = string_joiner_pipeline.run(data, break_point=break_point) except BreakpointException: pass - result = load_and_resume_pipeline_state( - string_joiner_pipeline, output_directory, component.component_name, data + result = load_and_resume_pipeline_snapshot( + pipeline=string_joiner_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, ) assert result["string_joiner"] From bbaf3099f193432587b22cb85d3215ec3bb5e053 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> Date: Fri, 18 Jul 2025 14:33:14 +0200 Subject: [PATCH 11/21] feat: Add dataclasses to represent a `PipelineSnapshot` and refactored to use it (#9619) * Refactor to use dataclasses for PipelineSnapshot and AgentSnapshot * Fix integration tests * Mypy * Fix mypy * Fix lint * Refactor AgentSnapshot to only contain needed info * Fix mypy * More refactoring * removing unused import --------- Co-authored-by: David S. Batista --- haystack/components/agents/agent.py | 264 +++++---- haystack/core/errors.py | 4 +- haystack/core/pipeline/breakpoint.py | 526 +++++++----------- haystack/core/pipeline/pipeline.py | 177 ++---- haystack/dataclasses/breakpoints.py | 184 +++++- .../test_agent_breakpoints_inside_pipeline.py | 42 +- .../test_agent_breakpoints_isolation_async.py | 16 +- .../test_agent_breakpoints_isolation_sync.py | 21 +- test/conftest.py | 4 +- test/core/pipeline/test_breakpoint.py | 70 +-- ...test_pipeline_breakpoints_answer_joiner.py | 2 +- ...test_pipeline_breakpoints_branch_joiner.py | 2 +- .../test_pipeline_breakpoints_list_joiner.py | 2 +- .../test_pipeline_breakpoints_loops.py | 2 +- .../test_pipeline_breakpoints_rag_hybrid.py | 2 +- ...test_pipeline_breakpoints_string_joiner.py | 2 +- 16 files changed, 669 insertions(+), 651 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 7fc18a54c6..d8c4fd9f1e 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -10,14 +10,20 @@ from haystack.components.tools import ToolInvoker from haystack.core.component.component import component from haystack.core.pipeline.async_pipeline import AsyncPipeline -from haystack.core.pipeline.breakpoint import _check_chat_generator_breakpoint, _check_tool_invoker_breakpoint +from haystack.core.pipeline.breakpoint import ( + _check_chat_generator_breakpoint, + _check_tool_invoker_breakpoint, + _create_agent_snapshot, + _validate_tool_breakpoint_is_valid, +) from haystack.core.pipeline.pipeline import Pipeline from haystack.core.pipeline.utils import _deepcopy_with_exceptions from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole -from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint +from 
haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, ToolBreakpoint from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset +from haystack.utils import _deserialize_value_with_schema from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from haystack.utils.deserialization import deserialize_chatgenerator_inplace @@ -232,28 +238,13 @@ def _create_agent_span(self) -> Any: }, ) - def _validate_tool_breakpoint_is_valid(self, agent_breakpoint: AgentBreakpoint) -> None: - """ - Validates the AgentBreakpoint passed to the agent. - - Validates that all tool names in ToolBreakpoints correspond to tools available in the agent. - - :param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components. - :raises ValueError: If any tool name in ToolBreakpoints is not available in the agent's tools. - """ - - available_tool_names = {tool.name for tool in self.tools} - tool_breakpoint = agent_breakpoint.break_point - if tool_breakpoint.tool_name is not None and tool_breakpoint.tool_name not in available_tool_names: # type: ignore # was checked outside function - raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools") # type: ignore # was checked outside function - - def run( + def run( # noqa: PLR0915 self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, *, break_point: Optional[AgentBreakpoint] = None, - pipeline_snapshot: Optional[Dict[str, Any]] = None, + snapshot: Optional[AgentSnapshot] = None, **kwargs: Any, ) -> Dict[str, Any]: """ @@ -265,7 +256,7 @@ def run( The same callback can be configured to emit tool results when a tool is called. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". - :param pipeline_snapshot: A dictionary containing the state of a previously saved agent execution. + :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -274,32 +265,46 @@ def run( - "last_message": The last message exchanged during the agent's run. - Any additional keys defined in the `state_schema`. :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. - :raises AgentBreakpointException: If an agent breakpoint is triggered. + :raises BreakpointException: If an agent breakpoint is triggered. """ + # kwargs can contain the key parent_snapshot. + # We pop it here to avoid passing it into State. We explicitly handle it and pass it on if a break point is + # triggered. + parent_snapshot = kwargs.pop("parent_snapshot", None) + if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.") - if break_point and pipeline_snapshot: + if break_point and snapshot: raise ValueError( - "agent_breakpoint and pipeline_snapshot cannot be provided at the same time. " - "The agent run will be aborted." + "break_point and snapshot cannot be provided at the same time. The agent run will be aborted."
) # validate breakpoints if break_point and isinstance(break_point.break_point, ToolBreakpoint): - self._validate_tool_breakpoint_is_valid(break_point) - - # Handle pipeline snapshot if provided - if pipeline_snapshot: - component_visits = pipeline_snapshot.get("pipeline_state", {}).get("component_visits", {}) - state_data = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {}) - - # deserialize messages from pipeline state - raw_messages = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("messages", messages) - # convert raw message dictionaries to ChatMessage objects and populate the state - messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages] + _validate_tool_breakpoint_is_valid(agent_breakpoint=break_point, tools=self.tools) + + # Handle agent snapshot if provided + if snapshot: + component_visits = snapshot.component_visits + current_inputs = { + "chat_generator": _deserialize_value_with_schema(snapshot.component_inputs["chat_generator"]), + "tool_invoker": _deserialize_value_with_schema(snapshot.component_inputs["tool_invoker"]), + } + state_data = current_inputs["tool_invoker"]["state"].data + if isinstance(snapshot.break_point.break_point, ToolBreakpoint): + # If the break point is a ToolBreakpoint, we need to get the messages from the tool invoker inputs + messages = current_inputs["tool_invoker"]["messages"] + # Needed to know if we should start with the ToolInvoker or ChatGenerator + skip_chat_generator = True + else: + messages = current_inputs["chat_generator"]["messages"] + skip_chat_generator = False + # We also load the streaming_callback from the snapshot if it exists + streaming_callback = current_inputs["chat_generator"].get("streaming_callback", None) else: + skip_chat_generator = False if self.system_prompt is not None: messages = [ChatMessage.from_system(self.system_prompt)] + messages @@ -328,25 +333,33 @@ def run( while counter < self.max_agent_steps: # check for breakpoint before ChatGenerator - _check_chat_generator_breakpoint( - agent_breakpoint=break_point, - component_visits=component_visits, - messages=messages, - generator_inputs=generator_inputs, - kwargs=kwargs, - state=state, - ) + if break_point and break_point.break_point.component_name == "chat_generator": + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages, **generator_inputs}, + "tool_invoker": {"messages": [], "state": state, "streaming_callback": streaming_callback}, + }, + ) + _check_chat_generator_breakpoint(agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot) # 1. Call the ChatGenerator - result = Pipeline._run_component( - component_name="chat_generator", - component={"instance": self.chat_generator}, - inputs={"messages": messages, **generator_inputs}, - component_visits=component_visits, - parent_span=span, - ) - llm_messages = result["replies"] - state.set("messages", llm_messages) + # We skip the chat generator when resuming from a snapshot that breaks at the ToolInvoker.
+            if skip_chat_generator: + llm_messages = state.get("messages", [])[-1:] + # We set it to False to ensure that the next iteration will call the chat generator again + skip_chat_generator = False + else: + result = Pipeline._run_component( + component_name="chat_generator", + component={"instance": self.chat_generator}, + inputs={"messages": messages, **generator_inputs}, + component_visits=component_visits, + parent_span=span, + ) + llm_messages = result["replies"] + state.set("messages", llm_messages) # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: @@ -354,15 +367,22 @@ def run( break # check for breakpoint before ToolInvoker - _check_tool_invoker_breakpoint( - agent_breakpoint=break_point, - component_visits=component_visits, - llm_messages=llm_messages, - streaming_callback=streaming_callback, - messages=messages, - kwargs=kwargs, - state=state, - ) + if break_point and break_point.break_point.component_name == "tool_invoker": + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages[:-1], **generator_inputs}, + "tool_invoker": { + "messages": llm_messages, + "state": state, + "streaming_callback": streaming_callback, + }, + }, + ) + _check_tool_invoker_breakpoint( + llm_messages=llm_messages, agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot + ) # 3. Call the ToolInvoker # We only send the messages from the LLM to the tool invoker @@ -406,7 +426,7 @@ async def run_async( # noqa: PLR0915 streaming_callback: Optional[StreamingCallbackT] = None, *, break_point: Optional[AgentBreakpoint] = None, - pipeline_snapshot: Optional[Dict[str, Any]] = None, + snapshot: Optional[AgentSnapshot] = None, **kwargs: Any, ) -> Dict[str, Any]: """ @@ -421,7 +441,7 @@ async def run_async( # noqa: PLR0915 is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". - :param pipeline_snapshot: A dictionary containing the state of a previously saved agent execution. + :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -430,33 +450,45 @@ async def run_async( # noqa: PLR0915 - "last_message": The last message exchanged during the agent's run. - Any additional keys defined in the `state_schema`. :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`. - :raises AgentBreakpointException: If an agent breakpoint is triggered. - + :raises BreakpointException: If an agent breakpoint is triggered. """ + # kwargs can contain the key parent_snapshot. + # We pop it here to avoid passing it into State. We explicitly handle it and pass it on if a break point is + # triggered. + parent_snapshot = kwargs.pop("parent_snapshot", None) + if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.") - if break_point and pipeline_snapshot: - msg = ( - "agent_breakpoint and pipeline_snapshot cannot be provided at the same time. " - "The agent run will be aborted."
+ if break_point and snapshot: + raise ValueError( + "break_point and snapshot cannot be provided at the same time. The agent run will be aborted." ) - raise ValueError(msg) # validate breakpoints if break_point and isinstance(break_point.break_point, ToolBreakpoint): - self._validate_tool_breakpoint_is_valid(break_point) - - # Handle pipeline snapshot if provided - if pipeline_snapshot: - component_visits = pipeline_snapshot.get("pipeline_state", {}).get("component_visits", {}) - state_data = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("state", {}).get("data", {}) - - # Extract and deserialize messages from pipeline state - raw_messages = pipeline_snapshot.get("pipeline_state", {}).get("inputs", {}).get("messages", messages) - # Convert raw message dictionaries to ChatMessage objects - messages = [ChatMessage.from_dict(msg) if isinstance(msg, dict) else msg for msg in raw_messages] + _validate_tool_breakpoint_is_valid(agent_breakpoint=break_point, tools=self.tools) + + # Handle agent snapshot if provided + if snapshot: + component_visits = snapshot.component_visits + current_inputs = { + "chat_generator": _deserialize_value_with_schema(snapshot.component_inputs["chat_generator"]), + "tool_invoker": _deserialize_value_with_schema(snapshot.component_inputs["tool_invoker"]), + } + state_data = current_inputs["tool_invoker"]["state"].data + if isinstance(snapshot.break_point.break_point, ToolBreakpoint): + # If the break point is a ToolBreakpoint, we need to get the messages from the tool invoker inputs + messages = current_inputs["tool_invoker"]["messages"] + # Needed to know if we should start with the ToolInvoker or ChatGenerator + skip_chat_generator = True + else: + messages = current_inputs["chat_generator"]["messages"] + skip_chat_generator = False + # We also load the streaming_callback from the snapshot if it exists + streaming_callback = current_inputs["chat_generator"].get("streaming_callback", None) else: + skip_chat_generator = False if self.system_prompt is not None: messages = [ChatMessage.from_system(self.system_prompt)] + messages @@ -484,26 +516,35 @@ async def run_async( # noqa: PLR0915 counter = 0 while counter < self.max_agent_steps: - # Check for breakpoint before ChatGenerator - _check_chat_generator_breakpoint( - agent_breakpoint=break_point, - component_visits=component_visits, - messages=messages, - generator_inputs=generator_inputs, - kwargs=kwargs, - state=state, - ) + # check for breakpoint before ChatGenerator + if break_point and break_point.break_point.component_name == "chat_generator": + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages, **generator_inputs}, + "tool_invoker": {"messages": [], "state": state, "streaming_callback": streaming_callback}, + }, + ) + _check_chat_generator_breakpoint(agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot) # 1. Call the ChatGenerator - result = await AsyncPipeline._run_component_async( - component_name="chat_generator", - component={"instance": self.chat_generator}, - component_inputs={"messages": messages, **generator_inputs}, - component_visits=component_visits, - parent_span=span, - ) - llm_messages = result["replies"] - state.set("messages", llm_messages) + # If skip_chat_generator is True, we skip the chat generator and use the messages from the state + # This is useful when the agent is resumed from a snapshot where the chat generator already ran. 
+ if skip_chat_generator: + llm_messages = state.get("messages", [])[-1:] # Get the last message from the state + # We set it to False to ensure that the next iteration will call the chat generator again + skip_chat_generator = False + else: + result = await AsyncPipeline._run_component_async( + component_name="chat_generator", + component={"instance": self.chat_generator}, + component_inputs={"messages": messages, **generator_inputs}, + component_visits=component_visits, + parent_span=span, + ) + llm_messages = result["replies"] + state.set("messages", llm_messages) # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: @@ -511,15 +552,22 @@ async def run_async( # noqa: PLR0915 break # Check for breakpoint before ToolInvoker - _check_tool_invoker_breakpoint( - agent_breakpoint=break_point, - component_visits=component_visits, - llm_messages=llm_messages, - streaming_callback=streaming_callback, - messages=messages, - kwargs=kwargs, - state=state, - ) + if break_point and break_point.break_point.component_name == "tool_invoker": + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages[:-1], **generator_inputs}, + "tool_invoker": { + "messages": llm_messages, + "state": state, + "streaming_callback": streaming_callback, + }, + }, + ) + _check_tool_invoker_breakpoint( + llm_messages=llm_messages, agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot + ) # 3. Call the ToolInvoker # We only send the messages from the LLM to the tool invoker diff --git a/haystack/core/errors.py b/haystack/core/errors.py index a05e1ea705..d3ddbc46c6 100644 --- a/haystack/core/errors.py +++ b/haystack/core/errors.py @@ -100,12 +100,12 @@ def __init__( self, message: str, component: Optional[str] = None, - pipeline_snapshot: Optional[Dict[str, Any]] = None, + inputs: Optional[Dict[str, Any]] = None, results: Optional[Dict[str, Any]] = None, ): super().__init__(message) self.component = component - self.pipeline_snapshot = pipeline_snapshot + self.inputs = inputs self.results = results diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py index 41c03d1551..168fe2628e 100644 --- a/haystack/core/pipeline/breakpoint.py +++ b/haystack/core/pipeline/breakpoint.py @@ -2,10 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=too-many-return-statements - import json from copy import deepcopy +from dataclasses import replace from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -13,16 +12,25 @@ from networkx import MultiDiGraph from haystack import logging -from haystack.components.agents.state import State from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError -from haystack.dataclasses import ChatMessage, StreamingCallbackT -from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import ( + AgentBreakpoint, + AgentSnapshot, + Breakpoint, + PipelineSnapshot, + PipelineState, + ToolBreakpoint, +) +from haystack.tools import Tool, Toolset from haystack.utils.base_serialization import _serialize_value_with_schema logger = logging.getLogger(__name__) -def _validate_break_point(break_point: Union[Breakpoint, AgentBreakpoint], graph: 
MultiDiGraph) -> None: +def _validate_break_point_against_pipeline( + break_point: Union[Breakpoint, AgentBreakpoint], graph: MultiDiGraph +) -> None: """ Validates the breakpoints passed to the pipeline. @@ -33,12 +41,12 @@ def _validate_break_point(break_point: Union[Breakpoint, AgentBreakpoint], graph # all Breakpoints must refer to a valid component in the pipeline if isinstance(break_point, Breakpoint) and break_point.component_name not in graph.nodes: - raise ValueError(f"pipeline_breakpoint {break_point} is not a registered component in the pipeline") + raise ValueError(f"break_point {break_point} is not a registered component in the pipeline") if isinstance(break_point, AgentBreakpoint): breakpoint_agent_component = graph.nodes.get(break_point.agent_name) if not breakpoint_agent_component: - raise ValueError(f"pipeline_breakpoint {break_point} is not a registered Agent component in the pipeline") + raise ValueError(f"break_point {break_point} is not a registered Agent component in the pipeline") if isinstance(break_point.break_point, ToolBreakpoint): instance = breakpoint_agent_component["instance"] @@ -47,11 +55,11 @@ def _validate_break_point(break_point: Union[Breakpoint, AgentBreakpoint], graph break else: raise ValueError( - f"pipeline_breakpoint {break_point.break_point} is not a registered tool in the Agent component" + f"break_point {break_point.break_point} is not a registered tool in the Agent component" ) -def _validate_components_against_pipeline(pipeline_snapshot: Dict[str, Any], graph: MultiDiGraph) -> None: +def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnapshot, graph: MultiDiGraph) -> None: """ Validates that the pipeline_snapshot contains valid configuration for the current pipeline. @@ -61,19 +69,19 @@ def _validate_components_against_pipeline(pipeline_snapshot: Dict[str, Any], gra :param pipeline_snapshot: The saved state to validate. """ - pipeline_state = pipeline_snapshot["pipeline_state"] + pipeline_state = pipeline_snapshot.pipeline_state valid_components = set(graph.nodes.keys()) # Check if the ordered_component_names are valid components in the pipeline - invalid_ordered_components = set(pipeline_state["ordered_component_names"]) - valid_components + invalid_ordered_components = set(pipeline_state.ordered_component_names) - valid_components if invalid_ordered_components: raise PipelineInvalidPipelineSnapshotError( f"Invalid pipeline snapshot: components {invalid_ordered_components} in 'ordered_component_names' " f"are not part of the current pipeline." ) - # Check if the input_data is valid components in the pipeline - serialized_input_data = pipeline_snapshot["input_data"]["serialized_data"] + # Check if the original_input_data is valid components in the pipeline + serialized_input_data = pipeline_snapshot.pipeline_state.original_input_data["serialized_data"] invalid_input_data = set(serialized_input_data.keys()) - valid_components if invalid_input_data: raise PipelineInvalidPipelineSnapshotError( @@ -82,58 +90,26 @@ def _validate_components_against_pipeline(pipeline_snapshot: Dict[str, Any], gra ) # Validate 'component_visits' - invalid_component_visits = set(pipeline_state["component_visits"].keys()) - valid_components + invalid_component_visits = set(pipeline_state.component_visits.keys()) - valid_components if invalid_component_visits: raise PipelineInvalidPipelineSnapshotError( f"Invalid pipeline snapshot: components {invalid_component_visits} in 'component_visits' " f"are not part of the current pipeline." 
) - logger.info( - f"Resuming pipeline from component: {pipeline_snapshot['pipeline_breakpoint']['component']} " - f"(visit {pipeline_snapshot['pipeline_breakpoint']['visits']})" - ) - - -def _validate_pipeline_snapshot(pipeline_snapshot: Dict[str, Any]) -> None: - """ - Validates the loaded pipeline snapshot. - - Ensures that the pipeline_snapshot contains required keys: "input_data", "pipeline_breakpoint", - and "pipeline_state". - - Raises: - ValueError: If required keys are missing or the component sets are inconsistent. - """ - - # top-level state has all required keys - required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} - missing_top = required_top_keys - pipeline_snapshot.keys() - if missing_top: - raise ValueError(f"Invalid pipeline_snapshot: missing required keys {missing_top}") - - # pipeline_state has the necessary keys - pipeline_state = pipeline_snapshot["pipeline_state"] - - required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} - missing_pipeline = required_pipeline_keys - pipeline_state.keys() - if missing_pipeline: - raise ValueError(f"Invalid pipeline_state: missing required keys {missing_pipeline}") - - # component_visits and ordered_component_names must be consistent - components_in_state = set(pipeline_state["component_visits"].keys()) - components_in_order = set(pipeline_state["ordered_component_names"]) + if isinstance(pipeline_snapshot.break_point, AgentBreakpoint): + component_name = pipeline_snapshot.break_point.agent_name + else: + component_name = pipeline_snapshot.break_point.component_name - if components_in_state != components_in_order: - raise ValueError( - f"Inconsistent state: components in pipeline_state['component_visits'] {components_in_state} " - f"do not match components in ordered_component_names {components_in_order}" - ) + visit_count = pipeline_snapshot.pipeline_state.component_visits[component_name] - logger.info("Pipeline snapshot validated successfully.") + logger.info( + "Resuming pipeline from {component} with visit count {visits}", component=component_name, visits=visit_count + ) -def load_pipeline_snapshot(file_path: Union[str, Path]) -> Dict[str, Any]: +def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot: """ Load a saved pipeline snapshot. @@ -146,7 +122,7 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> Dict[str, Any]: try: with open(file_path, "r", encoding="utf-8") as f: - pipeline_snapshot = json.load(f) + pipeline_snapshot_dict = json.load(f) except FileNotFoundError: raise FileNotFoundError(f"File not found: {file_path}") except json.JSONDecodeError as e: @@ -155,7 +131,7 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> Dict[str, Any]: raise IOError(f"Error reading {file_path}: {str(e)}") try: - _validate_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot) + pipeline_snapshot = PipelineSnapshot.from_dict(pipeline_snapshot_dict) except ValueError as e: raise ValueError(f"Invalid pipeline snapshot from {file_path}: {str(e)}") @@ -163,89 +139,84 @@ def load_pipeline_snapshot(file_path: Union[str, Path]) -> Dict[str, Any]: return pipeline_snapshot -def _process_main_pipeline_state(main_pipeline_state: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: - """ - Process and serialize main pipeline state for agent breakpoints. - - :param main_pipeline_state: Dictionary containing main pipeline state with keys: "component_visits", - "ordered_component_names", "original_input_data", and "inputs". 
- :returns: Processed main pipeline state or None if not available or invalid. - """ - if not main_pipeline_state: - return None - - original_input_data = main_pipeline_state.get("original_input_data") - inputs = main_pipeline_state.get("inputs") - - if not (original_input_data and inputs): - return None - - return { - "component_visits": main_pipeline_state.get("component_visits"), - "ordered_component_names": main_pipeline_state.get("ordered_component_names"), - "original_input_data": _serialize_value_with_schema(_transform_json_structure(original_input_data)), - "inputs": _serialize_value_with_schema(_transform_json_structure(inputs)), - } - - def _save_pipeline_snapshot_to_file( - *, pipeline_snapshot: Dict[str, Any], debug_path: Union[str, Path], dt: datetime, component_name: str + *, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime ) -> None: """ Save the pipeline snapshot dictionary to a JSON file. :param pipeline_snapshot: The pipeline snapshot to save. - :param debug_path: The path where to save the file. + :param snapshot_file_path: The path where to save the file. :param dt: The datetime object for timestamping. - :param component_name: Name of the component that triggered the breakpoint. :raises: - ValueError: If the debug_path is not a string or a Path object. + ValueError: If the snapshot_file_path is not a string or a Path object. Exception: If saving the JSON snapshot fails. """ - debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path - if not isinstance(debug_path, Path): + snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path + if not isinstance(snapshot_file_path, Path): raise ValueError("Debug path must be a string or a Path object.") - debug_path.mkdir(exist_ok=True) + snapshot_file_path.mkdir(exist_ok=True) # Generate filename # We check if the agent_name is provided to differentiate between agent and non-agent breakpoints - if pipeline_snapshot["agent_name"] is not None: - file_name = f"{pipeline_snapshot['agent_name']}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + if isinstance(pipeline_snapshot.break_point, AgentBreakpoint): + agent_name = pipeline_snapshot.break_point.agent_name + component_name = pipeline_snapshot.break_point.break_point.component_name + file_name = f"{agent_name}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" else: + component_name = pipeline_snapshot.break_point.component_name file_name = f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" try: - with open(debug_path / file_name, "w") as f_out: - json.dump(pipeline_snapshot, f_out, indent=2) + with open(snapshot_file_path / file_name, "w") as f_out: + json.dump(pipeline_snapshot.to_dict(), f_out, indent=2) logger.info(f"Pipeline snapshot saved at: {file_name}") except Exception as e: logger.error(f"Failed to save pipeline snapshot: {str(e)}") raise -def _save_snapshot( +def _create_pipeline_snapshot( *, inputs: Dict[str, Any], - component_name: str, + break_point: Union[AgentBreakpoint, Breakpoint], component_visits: Dict[str, int], - debug_path: Optional[Union[str, Path]] = None, original_input_data: Optional[Dict[str, Any]] = None, ordered_component_names: Optional[List[str]] = None, - agent_name: Optional[str] = None, - main_pipeline_state: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: +) -> PipelineSnapshot: """ - Save the pipeline snapshot to a file. + Create a snapshot of the pipeline at the point where the breakpoint was triggered. 
:param inputs: The current pipeline snapshot inputs. - :param component_name: The name of the component that triggered the breakpoint. + :param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint. :param component_visits: The visit count of the component that triggered the breakpoint. - :param debug_path: The path to save the snapshot to. :param original_input_data: The original input data. :param ordered_component_names: The ordered component names. - :param main_pipeline_state: Dictionary containing main pipeline state with keys: "component_visits", - "ordered_component_names", "original_input_data", and "inputs". + """ + dt = datetime.now() + + transformed_original_input_data = _transform_json_structure(original_input_data) + transformed_inputs = _transform_json_structure(inputs) + + pipeline_snapshot = PipelineSnapshot( + pipeline_state=PipelineState( + original_input_data=_serialize_value_with_schema(transformed_original_input_data), + inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs + component_visits=component_visits, + ordered_component_names=ordered_component_names or [], + ), + timestamp=dt, + break_point=break_point, + ) + return pipeline_snapshot + + +def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot: + """ + Save the pipeline snapshot to a file. + + :param pipeline_snapshot: The pipeline snapshot to save. :returns: The dictionary containing the snapshot of the pipeline containing the following keys: @@ -257,37 +228,17 @@ def _save_snapshot( - component_visits: The visit count of the components when the breakpoint was triggered. - ordered_component_names: The order of components in the pipeline. """ - dt = datetime.now() - - # remove duplicated information - if original_input_data: - original_input_data.pop("main_pipeline_state", None) - - transformed_original_input_data = _transform_json_structure(original_input_data) - transformed_inputs = _transform_json_structure(inputs) - - pipeline_snapshot = { - # related to the main pipeline where the agent running as a breakpoint - only used with AgentBreakpoint - "agent_name": agent_name if agent_name else None, - "main_pipeline_state": _process_main_pipeline_state(main_pipeline_state) if agent_name else None, - # breakpoint - information for the component that triggered the breakpoint, can also be an Agent - "component_name": component_name, - "input_data": _serialize_value_with_schema(transformed_original_input_data), # original input data - "timestamp": dt.isoformat(), - "pipeline_breakpoint": {"component": component_name, "visits": component_visits[component_name]}, - "pipeline_state": { - "inputs": _serialize_value_with_schema(transformed_inputs), # current pipeline state inputs - "component_visits": component_visits, - "ordered_component_names": ordered_component_names, - }, - } - - if not debug_path: - return pipeline_snapshot + break_point = pipeline_snapshot.break_point + if isinstance(break_point, AgentBreakpoint): + snapshot_file_path = break_point.break_point.snapshot_file_path + else: + snapshot_file_path = break_point.snapshot_file_path - _save_pipeline_snapshot_to_file( - pipeline_snapshot=pipeline_snapshot, debug_path=debug_path, dt=dt, component_name=component_name - ) + if snapshot_file_path is not None: + dt = pipeline_snapshot.timestamp or datetime.now() + _save_pipeline_snapshot_to_file( + pipeline_snapshot=pipeline_snapshot, snapshot_file_path=snapshot_file_path, dt=dt + ) return pipeline_snapshot @@ 
@@ -322,212 +273,167 @@ def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> An
     return data
 
 
-def _handle_agent_break_point(
-    *,
-    break_point: AgentBreakpoint,
-    component_name: str,
-    component_inputs: Dict[str, Any],
-    inputs: Dict[str, Any],
-    component_visits: Dict[str, int],
-    ordered_component_names: list,
-    data: Dict[str, Any],
-) -> Dict[str, Any]:
+def _trigger_break_point(*, pipeline_snapshot: PipelineSnapshot, pipeline_outputs: Dict[str, Any]) -> None:
     """
-    Handle agent-specific breakpoint logic.
-
-    :param break_point: The agent breakpoint to handle
-    :param component_name: Name of the current component
-    :param component_inputs: Inputs for the current component
-    :param inputs: Global pipeline inputs
-    :param component_visits: Component visit counts
-    :param ordered_component_names: Ordered list of component names
-    :param data: Original pipeline data
-    :return: Updated component inputs
+    Trigger a breakpoint by saving a snapshot and raising an exception.
+
+    :param pipeline_snapshot: The current pipeline snapshot containing the state and break point.
+    :param pipeline_outputs: Current pipeline outputs.
+    :raises BreakpointException: When the breakpoint is triggered.
     """
-    component_inputs["break_point"] = break_point
+    _save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot)
 
-    # Store pipeline state for agent resume
-    state_inputs_serialised = deepcopy(inputs)
-    state_inputs_serialised[component_name] = deepcopy(component_inputs)
-    component_inputs["main_pipeline_state"] = {
-        "inputs": state_inputs_serialised,
-        "component_visits": component_visits,
-        "ordered_component_names": ordered_component_names,
-        "original_input_data": data,
-    }
+    if isinstance(pipeline_snapshot.break_point, Breakpoint):
+        component_name = pipeline_snapshot.break_point.component_name
+    else:
+        component_name = pipeline_snapshot.break_point.agent_name
 
-    return component_inputs
+    component_visits = pipeline_snapshot.pipeline_state.component_visits
+    msg = f"Breaking at component {component_name} at visit count {component_visits[component_name]}"
+    raise BreakpointException(
+        message=msg, component=component_name, inputs=pipeline_snapshot.pipeline_state.inputs, results=pipeline_outputs
+    )
 
 
-def _check_regular_break_point(break_point: Breakpoint, component_name: str, component_visits: Dict[str, int]) -> bool:
+def _create_agent_snapshot(
+    *, component_visits: Dict[str, int], agent_breakpoint: AgentBreakpoint, component_inputs: Dict[str, Any]
+) -> AgentSnapshot:
     """
-    Check if a regular breakpoint should be triggered.
+    Create a snapshot of the agent's state.
 
-    :param break_point: The breakpoint to check
-    :param component_name: Name of the current component
-    :param component_visits: Component visit counts
-    :return: True if breakpoint should be triggered
+    :param component_visits: The visit counts for the agent's components.
+    :param agent_breakpoint: AgentBreakpoint object containing the breakpoints.
+    :param component_inputs: The inputs of the agent's components, keyed by component name.
+    :return: An AgentSnapshot containing the agent's state and component visits.
""" - return break_point.component_name == component_name and break_point.visit_count == component_visits[component_name] + return AgentSnapshot( + component_inputs={ + "chat_generator": _serialize_value_with_schema(deepcopy(component_inputs["chat_generator"])), + "tool_invoker": _serialize_value_with_schema(deepcopy(component_inputs["tool_invoker"])), + }, + component_visits=component_visits, + break_point=agent_breakpoint, + timestamp=datetime.now(), + ) -def _trigger_break_point( - *, - component_name: str, - component_inputs: Dict[str, Any], - inputs: Dict[str, Any], - component_visits: Dict[str, int], - debug_path: Optional[Union[str, Path]], - data: Dict[str, Any], - ordered_component_names: list, - pipeline_outputs: Dict[str, Any], -) -> None: +def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: Union[List[Tool], Toolset]) -> None: """ - Trigger a breakpoint by saving a snapshot and raising exception. + Validates the AgentBreakpoint passed to the agent. - :param component_name: Name of the component where breakpoint is triggered - :param component_inputs: Inputs for the current component - :param inputs: Global pipeline inputs - :param component_visits: Component visit counts - :param debug_path: Path for debug files - :param data: Original pipeline data - :param ordered_component_names: Ordered list of component names - :param pipeline_outputs: Current pipeline outputs - :raises PipelineBreakpointException: When breakpoint is triggered + Validates that the tool name in ToolBreakpoints correspond to a tool available in the agent. + + :param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components. + :param tools: List of Tool objects or a Toolset that the agent can use. + :raises ValueError: If any tool name in ToolBreakpoints is not available in the agent's tools. """ - pipeline_snapshot_inputs_serialised = deepcopy(inputs) - pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs) - _save_snapshot( - inputs=pipeline_snapshot_inputs_serialised, - component_name=str(component_name), - component_visits=component_visits, - debug_path=debug_path, - original_input_data=data, - ordered_component_names=ordered_component_names, - ) - msg = f"Breaking at component {component_name} at visit count {component_visits[component_name]}" - raise BreakpointException( - message=msg, - component=component_name, - pipeline_snapshot=pipeline_snapshot_inputs_serialised, - results=pipeline_outputs, - ) + available_tool_names = {tool.name for tool in tools} + tool_breakpoint = agent_breakpoint.break_point + # Assert added for mypy to pass, but this is already checked before this function is called + assert isinstance(tool_breakpoint, ToolBreakpoint) + if tool_breakpoint.tool_name and tool_breakpoint.tool_name not in available_tool_names: + raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools") def _check_chat_generator_breakpoint( - *, - agent_breakpoint: Optional[AgentBreakpoint], - component_visits: Dict[str, int], - messages: List[ChatMessage], - generator_inputs: Dict[str, Any], - kwargs: Dict[str, Any], - state: State, + *, agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot] ) -> None: """ Check for breakpoint before calling the ChatGenerator. 
-    :param agent_breakpoint: AgentBreakpoint object containing breakpoints
-    :param component_visits: Dictionary tracking component visit counts
-    :param messages: Current messages to process
-    :param generator_inputs: Inputs for the chat generator
-    :param kwargs: Additional keyword arguments
-    :param state: The current State of the agent
-    :raises AgentBreakpointException: If a breakpoint is triggered
+    :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints
+    :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
+    :raises BreakpointException: If a breakpoint is triggered
     """
-    # We also check component_name since ToolBreakpoint is a subclass of Breakpoint
-    if (
-        agent_breakpoint
-        and isinstance(agent_breakpoint.break_point, Breakpoint)
-        and agent_breakpoint.break_point.component_name == "chat_generator"
-    ):
-        break_point = agent_breakpoint.break_point
-        if component_visits[break_point.component_name] == break_point.visit_count:
-            chat_generator_inputs = deepcopy({"messages": messages, **generator_inputs})
-            _save_snapshot(
-                inputs=chat_generator_inputs,
-                component_name=break_point.component_name,
-                component_visits=component_visits,  # these are the component visits of the agent components
-                debug_path=break_point.debug_path,
-                original_input_data={"messages": messages, **kwargs},
-                ordered_component_names=["chat_generator", "tool_invoker"],
-                agent_name=agent_breakpoint.agent_name or "isolated_agent",
-                main_pipeline_state=state.data.get("main_pipeline_state", {}),
-            )
-            msg = f"Breaking at {break_point.component_name} visit count {component_visits[break_point.component_name]}"
-            logger.info(msg)
-            raise BreakpointException(
-                message=msg,
-                component=break_point.component_name,
-                pipeline_snapshot=chat_generator_inputs,
-                results=state.data,
-            )
+    break_point = agent_snapshot.break_point.break_point
+    if agent_snapshot.component_visits[break_point.component_name] != break_point.visit_count:
+        return
+
+    if parent_snapshot is None:
+        # Create an empty pipeline snapshot if no parent snapshot is provided
+        final_snapshot = PipelineSnapshot(
+            pipeline_state=PipelineState(
+                original_input_data={}, inputs={}, component_visits={}, ordered_component_names=[]
+            ),
+            timestamp=agent_snapshot.timestamp,
+            break_point=agent_snapshot.break_point,
+            agent_snapshot=agent_snapshot,
+        )
+    else:
+        final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot)
+    _save_pipeline_snapshot(pipeline_snapshot=final_snapshot)
+
+    msg = (
+        f"Breaking at {break_point.component_name} visit count "
+        f"{agent_snapshot.component_visits[break_point.component_name]}"
+    )
+    logger.info(msg)
+    raise BreakpointException(
+        message=msg,
+        component=break_point.component_name,
+        inputs=agent_snapshot.component_inputs,
+        results=agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"],
+    )
 
 
 def _check_tool_invoker_breakpoint(
-    *,
-    agent_breakpoint: Optional[AgentBreakpoint],
-    component_visits: Dict[str, int],
-    llm_messages: List[ChatMessage],
-    streaming_callback: Optional[StreamingCallbackT],
-    messages: List[ChatMessage],
-    kwargs: Dict[str, Any],
-    state: State,
+    *, llm_messages: List[ChatMessage], agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot]
 ) -> None:
     """
     Check for breakpoint before calling the ToolInvoker.
- :param agent_breakpoint: AgentBreakpoint object containing breakpoints - :param component_visits: Dictionary tracking component visit counts - :param llm_messages: Messages from the LLM - :param state: Current agent state - :param streaming_callback: Streaming callback function - :param messages: Original messages - :param kwargs: Additional keyword arguments - :raises AgentBreakpointException: If a breakpoint is triggered + :param llm_messages: List of ChatMessage objects containing potential tool calls. + :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints. + :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent. + :raises BreakpointException: If a breakpoint is triggered """ + if not isinstance(agent_snapshot.break_point.break_point, ToolBreakpoint): + return + + tool_breakpoint = agent_snapshot.break_point.break_point + # Check if the visit count matches + if agent_snapshot.component_visits[tool_breakpoint.component_name] != tool_breakpoint.visit_count: + return + + # Check if we should break for this specific tool or all tools + if tool_breakpoint.tool_name is None: + # Break for any tool call + should_break = any(msg.tool_call for msg in llm_messages) + else: + # Break only for the specific tool + should_break = any( + msg.tool_call and msg.tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages + ) - if agent_breakpoint and isinstance(agent_breakpoint.break_point, ToolBreakpoint): - tool_breakpoint = agent_breakpoint.break_point - # Check if the visit count matches - if component_visits[tool_breakpoint.component_name] == tool_breakpoint.visit_count: - # Check if we should break for this specific tool or all tools - should_break = False - if tool_breakpoint.tool_name is None: - # Break for any tool call - should_break = any(msg.tool_call for msg in llm_messages) - else: - # Break only for the specific tool - should_break = any( - msg.tool_call and msg.tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages - ) + if not should_break: + return # No breakpoint triggered + + if parent_snapshot is None: + # Create an empty pipeline snapshot if no parent snapshot is provided + final_snapshot = PipelineSnapshot( + pipeline_state=PipelineState( + original_input_data={}, inputs={}, component_visits={}, ordered_component_names=[] + ), + timestamp=agent_snapshot.timestamp, + break_point=agent_snapshot.break_point, + agent_snapshot=agent_snapshot, + ) + else: + final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot) + _save_pipeline_snapshot(pipeline_snapshot=final_snapshot) - if should_break: - tool_invoker_inputs = deepcopy( - {"messages": llm_messages, "state": state, "streaming_callback": streaming_callback} - ) - _save_snapshot( - inputs=tool_invoker_inputs, - component_name=tool_breakpoint.component_name, - component_visits=component_visits, - debug_path=tool_breakpoint.debug_path, - original_input_data={"messages": messages, **kwargs}, - ordered_component_names=["chat_generator", "tool_invoker"], - agent_name=agent_breakpoint.agent_name or "isolated_agent", - main_pipeline_state=state.data.get("main_pipeline_state", {}), - ) - msg = ( - f"Breaking at {tool_breakpoint.component_name} visit count " - f"{component_visits[tool_breakpoint.component_name]}" - ) - if tool_breakpoint.tool_name: - msg += f" for tool {tool_breakpoint.tool_name}" - logger.info(msg) - - raise BreakpointException( - message=msg, - component=tool_breakpoint.component_name, - 
pipeline_snapshot=tool_invoker_inputs, - results=state.data, - ) + msg = ( + f"Breaking at {tool_breakpoint.component_name} visit count " + f"{agent_snapshot.component_visits[tool_breakpoint.component_name]}" + ) + if tool_breakpoint.tool_name: + msg += f" for tool {tool_breakpoint.tool_name}" + logger.info(msg) + + raise BreakpointException( + message=msg, + component=tool_breakpoint.component_name, + inputs=agent_snapshot.component_inputs, + results=agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"], + ) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 7fb92fe35c..5cd2b1ad5f 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=too-many-positional-arguments - from copy import deepcopy from typing import Any, Dict, Mapping, Optional, Set, Union @@ -18,14 +16,13 @@ PipelineBase, ) from haystack.core.pipeline.breakpoint import ( - _check_regular_break_point, - _handle_agent_break_point, + _create_pipeline_snapshot, _trigger_break_point, - _validate_break_point, - _validate_components_against_pipeline, + _validate_break_point_against_pipeline, + _validate_pipeline_snapshot_against_pipeline, ) from haystack.core.pipeline.utils import _deepcopy_with_exceptions -from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot from haystack.telemetry import pipeline_running from haystack.utils import _deserialize_value_with_schema @@ -89,63 +86,13 @@ def _run_component( return component_output - def _handle_resume_pipeline( - self, pipeline_snapshot: Dict[str, Any] - ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: - """ - Handle resuming the pipeline from a pipeline snapshot. - - :param pipeline_snapshot: The snapshot of the pipeline to resume from. - :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) - """ - if pipeline_snapshot.get("agent_name"): - return self._handle_resume_from_agent(pipeline_snapshot) - else: - return self._handle_resume_from_pipeline_snapshot(pipeline_snapshot) - - def _handle_resume_from_agent( - self, pipeline_snapshot: Dict[str, Any] - ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: - """ - Handle resuming the pipeline at a specific Agent component. - - :param pipeline_snapshot: The snapshot of the pipeline to resume from. - :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names) - """ - agent_name = pipeline_snapshot["agent_name"] - for name, component in self.graph.nodes.items(): - if component["instance"].__class__.__name__ == "Agent" and name == agent_name: - main_pipeline_state = pipeline_snapshot.get("main_pipeline_state", {}) - component_visits = main_pipeline_state.get("component_visits", {}) - ordered_component_names = main_pipeline_state.get("ordered_component_names", []) - data = _deserialize_value_with_schema(main_pipeline_state.get("inputs", {})) - return component_visits, data, True, ordered_component_names - - # Fallback to regular resume if agent not found - return self._handle_resume_from_pipeline_snapshot(pipeline_snapshot) - - def _handle_resume_from_pipeline_snapshot( - self, pipeline_snapshot: Dict[str, Any] - ) -> tuple[Dict[str, int], Dict[str, Any], bool, list]: - """ - Handle resuming the pipeline from a regular pipeline snapshot. 
-
-        :param pipeline_snapshot: The snapshot of the pipeline to resume from.
-        :return: Tuple of (component_visits, data, resume_agent_in_pipeline, ordered_component_names)
-        """
-        component_visits, data, pipeline_snapshot, ordered_component_names = self._inject_pipeline_snapshot_into_graph(
-            pipeline_snapshot=pipeline_snapshot
-        )
-        data = _deserialize_value_with_schema(pipeline_snapshot["pipeline_state"]["inputs"])
-        return component_visits, data, False, ordered_component_names
-
-    def run(  # noqa: PLR0915, PLR0912, C901
+    def run(  # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches
         self,
         data: Dict[str, Any],
         include_outputs_from: Optional[Set[str]] = None,
         *,
         break_point: Optional[Union[Breakpoint, AgentBreakpoint]] = None,
-        pipeline_snapshot: Optional[Dict[str, Any]] = None,
+        pipeline_snapshot: Optional[PipelineSnapshot] = None,
     ) -> Dict[str, Any]:
         """
         Runs the Pipeline with given input data.
@@ -226,7 +173,7 @@ def run(  # noqa: PLR0915, PLR0912, C901
             A set of breakpoints that can be used to debug the pipeline execution.
 
         :param pipeline_snapshot:
-            A dictionary containing the state of a previously saved pipeline execution.
+            A PipelineSnapshot containing a snapshot of a previously saved pipeline execution.
 
         :returns:
             A dictionary where each entry corresponds to a component name
@@ -256,7 +203,7 @@ def run(  # noqa: PLR0915, PLR0912, C901
 
         # make sure all breakpoints are valid, i.e. reference components in the pipeline
         if break_point:
-            _validate_break_point(break_point, self.graph)
+            _validate_break_point_against_pipeline(break_point, self.graph)
 
         # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not
         # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run()
@@ -278,13 +225,16 @@ def run(  # noqa: PLR0915, PLR0912, C901
 
             # We track component visits to decide if a component can run.
             component_visits = dict.fromkeys(ordered_component_names, 0)
-            resume_agent_in_pipeline = False
         else:
+            # Validate the pipeline snapshot against the current pipeline graph
+            _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot, self.graph)
+
             # Handle resuming the pipeline from a snapshot
-            component_visits, data, resume_agent_in_pipeline, ordered_component_names = self._handle_resume_pipeline(
-                pipeline_snapshot
-            )
+            component_visits = pipeline_snapshot.pipeline_state.component_visits
+            ordered_component_names = pipeline_snapshot.pipeline_state.ordered_component_names
+            data = _deserialize_value_with_schema(pipeline_snapshot.pipeline_state.inputs)
 
         cached_topological_sort = None
         # We need to access a component's receivers multiple times during a pipeline run.
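From the caller's side, the break-and-resume cycle handled above looks roughly like this (a minimal sketch, assuming an already-built `pipeline` with a component named "prompt_builder"; the snapshot file name is illustrative):

from haystack.core.errors import BreakpointException
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses.breakpoints import Breakpoint

# Stop the first time "prompt_builder" runs and write a PipelineSnapshot JSON file.
break_point = Breakpoint(component_name="prompt_builder", visit_count=0, snapshot_file_path="snapshots/")
try:
    pipeline.run(data={"prompt_builder": {"query": "What is Haystack?"}}, break_point=break_point)
except BreakpointException:
    pass  # the snapshot was saved before the exception was raised

# Resume later: the snapshot carries the serialized inputs, so `data` can be empty.
snapshot = load_pipeline_snapshot("snapshots/prompt_builder_2025_07_14_17_50_23.json")
result = pipeline.run(data={}, pipeline_snapshot=snapshot)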
@@ -347,9 +297,14 @@ def run(  # noqa: PLR0915, PLR0912, C901
                 component_name, component_visits[component_name]
             )
 
-            is_resume = bool(
-                pipeline_snapshot and pipeline_snapshot["pipeline_breakpoint"]["component"] == component_name
-            )
+            if pipeline_snapshot:
+                if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
+                    name_to_check = pipeline_snapshot.break_point.agent_name
+                else:
+                    name_to_check = pipeline_snapshot.break_point.component_name
+                is_resume = name_to_check == component_name
+            else:
+                is_resume = False
             component_inputs = self._consume_component_inputs(
                 component_name=component_name, component=component, inputs=inputs, is_resume=is_resume
             )
@@ -362,44 +317,46 @@ def run(  # noqa: PLR0915, PLR0912, C901
             # Scenario 1: Pipeline snapshot is provided to resume the pipeline at a specific component
             # Deserialize the component_inputs if they are passed in the pipeline_snapshot.
             # this check will prevent other component_inputs generated at runtime from being deserialized
-            if pipeline_snapshot and component_name in pipeline_snapshot["pipeline_state"]["inputs"].keys():
+            if pipeline_snapshot and component_name in pipeline_snapshot.pipeline_state.inputs.keys():
                 for key, value in component_inputs.items():
                     component_inputs[key] = _deserialize_value_with_schema(value)
 
-            # Scenario 2: an AgentBreakpoint is provided to stop the pipeline at a specific component
-            if isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name:
-                component_inputs = _handle_agent_break_point(
+            # If we are resuming from an AgentBreakpoint, we inject the agent_snapshot into the Agent's inputs
+            if (
+                pipeline_snapshot
+                and isinstance(pipeline_snapshot.break_point, AgentBreakpoint)
+                and component_name == pipeline_snapshot.break_point.agent_name
+            ):
+                component_inputs["snapshot"] = pipeline_snapshot.agent_snapshot
+                component_inputs["break_point"] = None
+
+            # Scenario 2: A breakpoint is provided to stop the pipeline at a specific component
+            if break_point:
+                # Create a PipelineSnapshot to capture the current state of the pipeline
+                pipeline_snapshot_inputs_serialised = deepcopy(inputs)
+                pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs)
+                new_pipeline_snapshot = _create_pipeline_snapshot(
+                    inputs=pipeline_snapshot_inputs_serialised,
                     break_point=break_point,
-                    component_name=component_name,
-                    component_inputs=component_inputs,
-                    inputs=inputs,
                     component_visits=component_visits,
+                    original_input_data=data,
                     ordered_component_names=ordered_component_names,
-                    data=data,
                 )
 
-            # Scenario 3: a regular breakpoint is provided to stop the pipeline at a specific component and
-            # visit count
-            if isinstance(break_point, Breakpoint):
-                breakpoint_triggered = _check_regular_break_point(
-                    break_point=break_point, component_name=component_name, component_visits=component_visits
-                )
-                if breakpoint_triggered:
-                    _trigger_break_point(
-                        component_name=component_name,
-                        component_inputs=component_inputs,
-                        inputs=inputs,
-                        component_visits=component_visits,
-                        debug_path=break_point.debug_path,
-                        data=data,
-                        ordered_component_names=ordered_component_names,
-                        pipeline_outputs=pipeline_outputs,
-                    )
-
-            if resume_agent_in_pipeline:
-                # inject the pipeline_snapshot into the component (the Agent) inputs
-                component_inputs["pipeline_snapshot"] = pipeline_snapshot
-                component_inputs["break_point"] = None
+                # Scenario 2.1: an AgentBreakpoint is provided to stop the pipeline at a specific component
+                if isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name:
+                    # Add the break_point and pipeline_snapshot to the agent's component inputs
+                    component_inputs["break_point"] = break_point
+                    component_inputs["parent_snapshot"] = new_pipeline_snapshot
+
+                # Scenario 2.2: a regular breakpoint is provided to stop the pipeline at a specific component and
+                # visit count
+                if (
+                    isinstance(break_point, Breakpoint)
+                    and break_point.component_name == component_name
+                    and break_point.visit_count == component_visits[component_name]
+                ):
+                    _trigger_break_point(pipeline_snapshot=new_pipeline_snapshot, pipeline_outputs=pipeline_outputs)
 
             component_outputs = self._run_component(
                 component_name=component_name,
@@ -433,27 +390,3 @@ def run(  # noqa: PLR0915, PLR0912, C901
         )
 
         return pipeline_outputs
-
-    def _inject_pipeline_snapshot_into_graph(
-        self, pipeline_snapshot: Dict[str, Any]
-    ) -> tuple[Dict[str, int], Dict[str, Any], Dict[str, Any], list]:
-        """
-        Injects the pipeline snapshot into the current pipeline graph.
-        """
-        # We previously check if the pipeline_snapshot is None but this is needed to prevent a typing error
-        if not pipeline_snapshot:
-            raise PipelineInvalidPipelineSnapshotError("Cannot inject pipeline_snapshot: pipeline_snapshot is None")
-
-        # check if the pipeline_snapshot is valid for the current pipeline
-        _validate_components_against_pipeline(pipeline_snapshot, self.graph)
-
-        data = self._prepare_component_input_data(pipeline_snapshot["pipeline_state"]["inputs"])
-        component_visits = pipeline_snapshot["pipeline_state"]["component_visits"]
-        ordered_component_names = pipeline_snapshot["pipeline_state"]["ordered_component_names"]
-        logger.info(
-            "Resuming pipeline from {component} with visit count {visits}",
-            component=pipeline_snapshot["pipeline_breakpoint"]["component"],
-            visits=pipeline_snapshot["pipeline_breakpoint"]["visits"],
-        )
-
-        return component_visits, data, pipeline_snapshot, ordered_component_names
diff --git a/haystack/dataclasses/breakpoints.py b/haystack/dataclasses/breakpoints.py
index 263e5e1986..d67e02e644 100644
--- a/haystack/dataclasses/breakpoints.py
+++ b/haystack/dataclasses/breakpoints.py
@@ -2,8 +2,9 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
-from typing import Optional, Union
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
 
 
 @dataclass(frozen=True)
@@ -13,14 +14,32 @@ class Breakpoint:
 
     :param component_name: The name of the component where the breakpoint is set.
     :param visit_count: The number of times the component must be visited before the breakpoint is triggered.
-    :param debug_path: Optional path to store a snapshot of the pipeline when the breakpoint is hit.
+    :param snapshot_file_path: Optional path to store a snapshot of the pipeline when the breakpoint is hit.
         This is useful for debugging purposes, allowing you to inspect the state of the pipeline at the time of the
         breakpoint and to resume execution from that point.
     """
 
     component_name: str
    visit_count: int = 0
-    debug_path: Optional[str] = None
+    snapshot_file_path: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the Breakpoint to a dictionary representation.
+
+        :return: A dictionary containing the component name, visit count, and snapshot file path.
+        """
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "Breakpoint":
+        """
+        Populate the Breakpoint from a dictionary representation.
+
+        :param data: A dictionary containing the component name, visit count, and snapshot file path.
+        :return: An instance of Breakpoint.
+        """
+        return cls(**data)
 
 
 @dataclass(frozen=True)
@@ -36,12 +55,12 @@ class ToolBreakpoint(Breakpoint):
 
     tool_name: Optional[str] = None
 
-    def __str__(self):
+    def __str__(self) -> str:
         tool_str = f", tool_name={self.tool_name}" if self.tool_name else ", tool_name=ALL_TOOLS"
         return f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}{tool_str})"
 
 
-@dataclass
+@dataclass(frozen=True)
 class AgentBreakpoint:
     """
     A dataclass representing a breakpoint tied to an Agent’s execution.
@@ -68,3 +87,156 @@ def __post_init__(self):
 
         if isinstance(self.break_point, ToolBreakpoint) and self.break_point.component_name != "tool_invoker":
             raise ValueError("If the break_point is a ToolBreakpoint, it must have the component_name 'tool_invoker'.")
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the AgentBreakpoint to a dictionary representation.
+
+        :return: A dictionary containing the agent name and the breakpoint details.
+        """
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "AgentBreakpoint":
+        """
+        Populate the AgentBreakpoint from a dictionary representation.
+
+        :param data: A dictionary containing the agent name and the breakpoint details.
+        :return: An instance of AgentBreakpoint.
+        """
+        break_point_data = data["break_point"]
+        break_point: Union[Breakpoint, ToolBreakpoint]
+        if "tool_name" in break_point_data:
+            break_point = ToolBreakpoint(**break_point_data)
+        else:
+            break_point = Breakpoint(**break_point_data)
+        return cls(agent_name=data["agent_name"], break_point=break_point)
+
+
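As an illustration of the serialization round-trip these methods provide (a sketch; the agent and tool names are hypothetical):

from haystack.dataclasses.breakpoints import AgentBreakpoint, ToolBreakpoint

# Fires the first time the agent calls the "search" tool; a ToolBreakpoint must
# target the agent's "tool_invoker" component.
tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="search")
agent_bp = AgentBreakpoint(agent_name="my_agent", break_point=tool_bp)

# to_dict()/from_dict() round-trip, as used when snapshots are written to and read from JSON.
assert AgentBreakpoint.from_dict(agent_bp.to_dict()) == agent_bp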
+@dataclass
+class AgentSnapshot:
+    """
+    A dataclass to hold a snapshot of the agent's state at the moment a breakpoint was triggered.
+
+    :param component_inputs: The serialized inputs of the agent's components.
+    :param component_visits: A dictionary mapping the agent's component names to their visit counts.
+    :param break_point: The AgentBreakpoint that triggered the snapshot.
+    :param timestamp: A timestamp indicating when the snapshot was taken.
+    """
+
+    component_inputs: Dict[str, Any]
+    component_visits: Dict[str, int]
+    break_point: AgentBreakpoint
+    timestamp: Optional[datetime] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the AgentSnapshot to a dictionary representation.
+
+        :return: A dictionary containing the agent state, timestamp, and breakpoint.
+        """
+        return {
+            "component_inputs": self.component_inputs,
+            "component_visits": self.component_visits,
+            "break_point": self.break_point.to_dict(),
+            "timestamp": self.timestamp.isoformat() if self.timestamp else None,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "AgentSnapshot":
+        """
+        Populate the AgentSnapshot from a dictionary representation.
+
+        :param data: A dictionary containing the agent state, timestamp, and breakpoint.
+        :return: An instance of AgentSnapshot.
+        """
+        return cls(
+            component_inputs=data["component_inputs"],
+            component_visits=data["component_visits"],
+            break_point=AgentBreakpoint.from_dict(data["break_point"]),
+            timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None,
+        )
+
+
+@dataclass
+class PipelineState:
+    """
+    A dataclass to hold the state of the pipeline at a specific point in time.
+
+    :param original_input_data: The original input data provided to the pipeline.
+    :param inputs: The inputs processed by the pipeline at the time of the snapshot.
+    :param component_visits: A dictionary mapping component names to their visit counts.
+    :param ordered_component_names: A list of component names in the order they were visited.
+    """
+
+    original_input_data: Dict[str, Any]
+    inputs: Dict[str, Any]
+    component_visits: Dict[str, int]
+    ordered_component_names: List[str]
+
+    def __post_init__(self):
+        components_in_state = set(self.component_visits.keys())
+        components_in_order = set(self.ordered_component_names)
+
+        if components_in_state != components_in_order:
+            raise ValueError(
+                f"Inconsistent state: components in PipelineState.component_visits {components_in_state} "
+                f"do not match components in PipelineState.ordered_component_names {components_in_order}"
+            )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the PipelineState to a dictionary representation.
+
+        :return: A dictionary containing the original input data, inputs, component visits, and ordered component
+            names.
+        """
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "PipelineState":
+        """
+        Populate the PipelineState from a dictionary representation.
+
+        :param data: A dictionary containing the original input data, inputs, component visits, and ordered component
+            names.
+        :return: An instance of PipelineState.
+        """
+        return cls(**data)
+
+
+@dataclass
+class PipelineSnapshot:
+    """
+    A dataclass to hold a snapshot of the pipeline at a specific point in time.
+
+    :param pipeline_state: The state of the pipeline at the time of the snapshot.
+    :param break_point: The breakpoint that triggered the snapshot.
+    :param agent_snapshot: An optional AgentSnapshot, set when the breakpoint was triggered inside an Agent.
+    :param timestamp: A timestamp indicating when the snapshot was taken.
+    """
+
+    pipeline_state: PipelineState
+    break_point: Union[AgentBreakpoint, Breakpoint]
+    agent_snapshot: Optional[AgentSnapshot] = None
+    timestamp: Optional[datetime] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Convert the PipelineSnapshot to a dictionary representation.
+
+        :return: A dictionary containing the pipeline state, breakpoint, agent snapshot, and timestamp.
+        """
+        return {
+            "pipeline_state": self.pipeline_state.to_dict(),
+            "break_point": self.break_point.to_dict(),
+            "agent_snapshot": self.agent_snapshot.to_dict() if self.agent_snapshot else None,
+            "timestamp": self.timestamp.isoformat() if self.timestamp else None,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "PipelineSnapshot":
+        """
+        Populate the PipelineSnapshot from a dictionary representation.
+
+        :param data: A dictionary containing the pipeline state, breakpoint, agent snapshot, and timestamp.
+        :return: An instance of PipelineSnapshot.
+ """ + return cls( + pipeline_state=PipelineState.from_dict(data=data["pipeline_state"]), + break_point=( + AgentBreakpoint.from_dict(data=data["break_point"]) + if "agent_name" in data["break_point"] + else Breakpoint.from_dict(data=data["break_point"]) + ), + agent_snapshot=AgentSnapshot.from_dict(data["agent_snapshot"]) if data.get("agent_snapshot") else None, + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + ) diff --git a/test/components/agents/test_agent_breakpoints_inside_pipeline.py b/test/components/agents/test_agent_breakpoints_inside_pipeline.py index 817e1fdece..91a4d6ebac 100644 --- a/test/components/agents/test_agent_breakpoints_inside_pipeline.py +++ b/test/components/agents/test_agent_breakpoints_inside_pipeline.py @@ -20,7 +20,7 @@ from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint from haystack.document_stores.in_memory import InMemoryDocumentStore -from haystack.tools import tool +from haystack.tools import create_tool_from_function document_store = InMemoryDocumentStore() @@ -95,13 +95,16 @@ def run(self, sources: List[ByteStream]) -> Dict[str, List[Document]]: return {"documents": documents} -@tool -def add_database_tool(name: str, surname: str, job_title: Optional[str], other: Optional[str]): +def add_database_tool_function(name: str, surname: str, job_title: Optional[str], other: Optional[str]): document_store.write_documents( [Document(content=name + " " + surname + " " + (job_title or ""), meta={"other": other})] ) +# We use this since the @tool decorator has issues with deserialization +add_database_tool = create_tool_from_function(add_database_tool_function, name="add_database_tool") + + @pytest.fixture def pipeline_with_agent(monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test_key") @@ -219,7 +222,7 @@ def run_pipeline_without_any_breakpoints(pipeline_with_agent): def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: - agent_generator_breakpoint = Breakpoint("chat_generator", 0, debug_path=debug_path) + agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path) agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( @@ -229,8 +232,8 @@ def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent): except BreakpointException as e: # this is the exception from the Agent assert e.component == "chat_generator" - assert e.pipeline_snapshot is not None - assert "messages" in e.pipeline_snapshot + assert e.inputs is not None + assert "messages" in e.inputs["chat_generator"]["serialized_data"] assert e.results is not None # verify that snapshot file was created @@ -240,7 +243,9 @@ def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent): def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: - agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, tool_name="add_database_tool", debug_path=debug_path) + agent_tool_breakpoint = ToolBreakpoint( + "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path + ) agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( @@ -249,8 +254,8 @@ def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent): 
assert False, "Expected exception was not raised" except BreakpointException as e: # this is the exception from the Agent assert e.component == "tool_invoker" - assert e.pipeline_snapshot is not None - assert "messages" in e.pipeline_snapshot + assert e.inputs is not None + assert "messages" in e.inputs["tool_invoker"]["serialized_data"] assert e.results is not None # verify that snapshot file was created @@ -260,7 +265,7 @@ def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent): def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: - agent_generator_breakpoint = Breakpoint("chat_generator", 0, debug_path=debug_path) + agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path) agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( @@ -270,8 +275,8 @@ def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent except BreakpointException as e: assert e.component == "chat_generator" - assert e.pipeline_snapshot is not None - assert "messages" in e.pipeline_snapshot + assert e.inputs is not None + assert "messages" in e.inputs["chat_generator"]["serialized_data"] assert e.results is not None # verify that the snapshot file was created @@ -306,7 +311,9 @@ def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent): with tempfile.TemporaryDirectory() as debug_path: - agent_tool_breakpoint = ToolBreakpoint("tool_invoker", 0, tool_name="add_database_tool", debug_path=debug_path) + agent_tool_breakpoint = ToolBreakpoint( + "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path + ) agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent") try: pipeline_with_agent.run( @@ -316,8 +323,10 @@ def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent): except BreakpointException as e: assert e.component == "tool_invoker" - assert e.pipeline_snapshot is not None - assert "messages" in e.pipeline_snapshot + assert e.inputs is not None + assert "serialization_schema" in e.inputs["tool_invoker"] + assert "serialized_data" in e.inputs["tool_invoker"] + assert "messages" in e.inputs["tool_invoker"]["serialized_data"] assert e.results is not None # verify that the snapshot file was created @@ -326,7 +335,8 @@ def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent): # resume the pipeline from the saved snapshot latest_snapshot_file = max(tool_invoker_snapshot_files, key=os.path.getctime) - result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file)) + pipeline_snapshot = load_pipeline_snapshot(latest_snapshot_file) + result = pipeline_with_agent.run(data={}, pipeline_snapshot=pipeline_snapshot) # pipeline completed successfully after resuming assert "database_agent" in result diff --git a/test/components/agents/test_agent_breakpoints_isolation_async.py b/test/components/agents/test_agent_breakpoints_isolation_async.py index 297b7df711..a41da346f3 100644 --- a/test/components/agents/test_agent_breakpoints_isolation_async.py +++ b/test/components/agents/test_agent_breakpoints_isolation_async.py @@ -86,7 +86,7 @@ async def test_run_async_with_chat_generator_breakpoint(agent): with pytest.raises(BreakpointException) as exc_info: await agent.run_async(messages=messages, 
break_point=agent_breakpoint, agent_name=AGENT_NAME) assert exc_info.value.component == "chat_generator" - assert "messages" in exc_info.value.pipeline_snapshot + assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"] @pytest.mark.asyncio @@ -100,13 +100,13 @@ async def test_run_async_with_tool_invoker_breakpoint(mock_agent_with_tool_calls ) assert exc_info.value.component == "tool_invoker" - assert "messages" in exc_info.value.pipeline_snapshot + assert "messages" in exc_info.value.inputs["tool_invoker"]["serialized_data"] @pytest.mark.asyncio async def test_resume_from_chat_generator_async(agent, debug_path): messages = [ChatMessage.from_user("What's the weather in Berlin?")] - chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, debug_path=debug_path) + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, snapshot_file_path=debug_path) agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=AGENT_NAME) try: @@ -121,7 +121,7 @@ async def test_resume_from_chat_generator_async(agent, debug_path): result = await agent.run_async( messages=[ChatMessage.from_user("Continue from where we left off.")], - pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, ) assert "messages" in result @@ -133,7 +133,7 @@ async def test_resume_from_chat_generator_async(agent, debug_path): async def test_resume_from_tool_invoker_async(mock_agent_with_tool_calls, debug_path): messages = [ChatMessage.from_user("What's the weather in Berlin?")] tool_bp = ToolBreakpoint( - component_name="tool_invoker", visit_count=0, tool_name="weather_tool", debug_path=debug_path + component_name="tool_invoker", visit_count=0, tool_name="weather_tool", snapshot_file_path=debug_path ) agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=AGENT_NAME) @@ -151,7 +151,7 @@ async def test_resume_from_tool_invoker_async(mock_agent_with_tool_calls, debug_ result = await mock_agent_with_tool_calls.run_async( messages=[ChatMessage.from_user("Continue from where we left off.")], - pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, ) assert "messages" in result @@ -164,9 +164,9 @@ async def test_invalid_combination_breakpoint_and_pipeline_snapshot_async(mock_a messages = [ChatMessage.from_user("What's the weather in Berlin?")] tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") - with pytest.raises(ValueError, match="agent_breakpoint and pipeline_snapshot cannot be provided at the same time"): + with pytest.raises(ValueError, match="break_point and snapshot cannot be provided at the same time"): await mock_agent_with_tool_calls.run_async( - messages=messages, break_point=agent_breakpoint, pipeline_snapshot={"some": "snapshot"} + messages=messages, break_point=agent_breakpoint, snapshot={"some": "snapshot"} ) diff --git a/test/components/agents/test_agent_breakpoints_isolation_sync.py b/test/components/agents/test_agent_breakpoints_isolation_sync.py index dd649479ea..4afb4b92a9 100644 --- a/test/components/agents/test_agent_breakpoints_isolation_sync.py +++ b/test/components/agents/test_agent_breakpoints_isolation_sync.py @@ -27,7 +27,7 @@ def test_run_with_chat_generator_breakpoint(agent_sync): # noqa: F811 with pytest.raises(BreakpointException) as 
exc_info: agent_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") assert exc_info.value.component == "chat_generator" - assert "messages" in exc_info.value.pipeline_snapshot + assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"] def test_run_with_tool_invoker_breakpoint(mock_agent_with_tool_calls_sync): # noqa: F811 @@ -38,13 +38,16 @@ def test_run_with_tool_invoker_breakpoint(mock_agent_with_tool_calls_sync): # n mock_agent_with_tool_calls_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") assert exc_info.value.component == "tool_invoker" - assert "messages" in exc_info.value.pipeline_snapshot + assert {"chat_generator", "tool_invoker"} == set(exc_info.value.inputs.keys()) + assert "serialization_schema" in exc_info.value.inputs["chat_generator"] + assert "serialized_data" in exc_info.value.inputs["chat_generator"] + assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"] def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] debug_path = str(tmp_path / "debug_snapshots") - chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, debug_path=debug_path) + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, snapshot_file_path=debug_path) agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=AGENT_NAME) try: @@ -58,7 +61,7 @@ def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 result = agent_sync.run( messages=[ChatMessage.from_user("Continue from where we left off.")], - pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, ) assert "messages" in result @@ -69,7 +72,9 @@ def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 def test_resume_from_tool_invoker(mock_agent_with_tool_calls_sync, tmp_path): # noqa: F811 messages = [ChatMessage.from_user("What's the weather in Berlin?")] debug_path = str(tmp_path / "debug_snapshots") - tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name=None, debug_path=debug_path) + tool_bp = ToolBreakpoint( + component_name="tool_invoker", visit_count=0, tool_name=None, snapshot_file_path=debug_path + ) agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=AGENT_NAME) try: @@ -83,7 +88,7 @@ def test_resume_from_tool_invoker(mock_agent_with_tool_calls_sync, tmp_path): # result = mock_agent_with_tool_calls_sync.run( messages=[ChatMessage.from_user("Continue from where we left off.")], - pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file), + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, ) assert "messages" in result @@ -95,9 +100,9 @@ def test_invalid_combination_breakpoint_and_pipeline_snapshot(mock_agent_with_to messages = [ChatMessage.from_user("What's the weather in Berlin?")] tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") - with pytest.raises(ValueError, match="agent_breakpoint and pipeline_snapshot cannot be provided at the same time"): + with pytest.raises(ValueError, match="break_point and snapshot cannot be provided at the same time"): mock_agent_with_tool_calls_sync.run( - messages=messages, break_point=agent_breakpoint, pipeline_snapshot={"some": "snapshot"} + 
messages=messages, break_point=agent_breakpoint, snapshot={"some": "snapshot"} ) diff --git a/test/conftest.py b/test/conftest.py index 002c026d12..7c1d103c02 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -105,8 +105,8 @@ def load_and_resume_pipeline_snapshot(pipeline, output_directory: Path, componen for full_path in all_files: f_name = Path(full_path).name if str(f_name).startswith(component_name): - resume_state = load_pipeline_snapshot(full_path) - return pipeline.run(data=data, pipeline_snapshot=resume_state) + pipeline_snapshot = load_pipeline_snapshot(full_path) + return pipeline.run(data=data, pipeline_snapshot=pipeline_snapshot) if not file_found: msg = f"No files found for {component_name} in {output_directory}." diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py index 505292668c..53ad8f150d 100644 --- a/test/core/pipeline/test_breakpoint.py +++ b/test/core/pipeline/test_breakpoint.py @@ -6,11 +6,8 @@ import pytest -from haystack.core.pipeline.breakpoint import ( - _transform_json_structure, - _validate_pipeline_snapshot, - load_pipeline_snapshot, -) +from haystack.core.pipeline.breakpoint import _transform_json_structure, load_pipeline_snapshot +from haystack.dataclasses.breakpoints import PipelineSnapshot def test_transform_json_structure_unwraps_sender_value(): @@ -37,64 +34,11 @@ def test_transform_json_structure_handles_nested_structures(): assert result == {"key1": "value1", "key2": {"nested": "value2", "direct": "value3"}, "key3": ["value4", "value5"]} -def test_validate_pipeline_snapshot_validates_required_keys(): - pipeline_snapshot = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - # Missing pipeline_state - } - - with pytest.raises(ValueError, match="Invalid pipeline_snapshot: missing required keys"): - _validate_pipeline_snapshot(pipeline_snapshot) - - pipeline_snapshot = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {}, - # Missing ordered_component_names - }, - } - - with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"): - _validate_pipeline_snapshot(pipeline_snapshot) - - -def test_validate_pipeline_snapshot_validates_component_consistency(): - pipeline_snapshot = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp3"], # inconsistent with component_visits - }, - } - - with pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): - _validate_pipeline_snapshot(pipeline_snapshot) - - -def test_validate_pipeline_snapshot_validates_valid_snapshot(): - pipeline_snapshot = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp2"], - }, - } - - _validate_pipeline_snapshot(pipeline_snapshot) # should not raise any exception - - def test_load_pipeline_snapshot_loads_valid_snapshot(tmp_path): pipeline_snapshot = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "break_point": {"component_name": "comp1", "visit_count": 0}, "pipeline_state": { + "original_input_data": {}, "inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, "ordered_component_names": ["comp1", "comp2"], @@ 
-105,14 +49,14 @@ def test_load_pipeline_snapshot_loads_valid_snapshot(tmp_path): json.dump(pipeline_snapshot, f) loaded_snapshot = load_pipeline_snapshot(pipeline_snapshot_file) - assert loaded_snapshot == pipeline_snapshot + assert loaded_snapshot == PipelineSnapshot.from_dict(pipeline_snapshot) def test_load_state_handles_invalid_state(tmp_path): pipeline_snapshot = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "break_point": {"component_name": "comp1", "visit_count": 0}, "pipeline_state": { + "original_input_data": {}, "inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, "ordered_component_names": ["comp1", "comp3"], # inconsistent with component_visits diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py index 9752ac9c01..af7cf70326 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -111,7 +111,7 @@ def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_d } # Create a Breakpoint on-the-fly using the shared output directory - break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) try: _ = answer_join_pipeline.run(data, break_point=break_point) diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py index 75e526e57b..e5e2618891 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -106,7 +106,7 @@ def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output } # Create a Breakpoint on-the-fly using the shared output directory - break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) try: _ = branch_joiner_pipeline.run(data, break_point=break_point) diff --git a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py index 7281585fdb..9a89abac3f 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py @@ -125,7 +125,7 @@ def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, comp } # Create a Breakpoint on-the-fly using the shared output directory - break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) try: _ = list_joiner_pipeline.run(data, break_point=break_point) diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py index fd4365d8b8..f76dd6d48c 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_loops.py +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -211,7 +211,7 @@ def test_pipeline_breakpoints_validation_loop( data = {"prompt_builder": {"passage": test_data["passage"], "schema": test_data["schema"]}} # Create a Breakpoint on-the-fly using the shared output directory - break_point = Breakpoint(component_name=component, visit_count=0, 
debug_path=str(output_directory)) + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) try: _ = validation_loop_pipeline.run(data, break_point=break_point) diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py index 6c9f8a95fe..820942c247 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -285,7 +285,7 @@ def test_pipeline_breakpoints_hybrid_rag( } # Create a Breakpoint on-the-fly using the shared output directory - break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) try: _ = hybrid_rag_pipeline.run(data, break_point=break_point) diff --git a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py index 4b791535c6..1ab96c34f6 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py @@ -49,7 +49,7 @@ def test_string_joiner_pipeline(self, string_joiner_pipeline, output_directory, data = {"prompt_builder_1": {"query": string_1}, "prompt_builder_2": {"query": string_2}} # Create a Breakpoint on-the-fly using the shared output directory - break_point = Breakpoint(component_name=component, visit_count=0, debug_path=str(output_directory)) + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) try: _ = string_joiner_pipeline.run(data, break_point=break_point) From b87d2b3ddfcb5cd3ff40cbc585381503a5138966 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Wed, 23 Jul 2025 12:32:45 +0100 Subject: [PATCH 12/21] feat: saving include_outputs_from intermediate results to `PipelineState` object (#9629) * saving intermediate components results in include_outputs_from into the PipelineSnaptshot * cleaning up * fixing tests * fixing tests * extending tests * Update haystack/dataclasses/breakpoints.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Update haystack/dataclasses/breakpoints.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * linting * moving intermediate results to pipeline state and adding pipeline outputs to state * moving ordered_component_names and include_outputs_from to PipelineSnapshot * moving original_input_data to PipelineSnapshot * simplifying saving the intermediate results * Update haystack/dataclasses/breakpoints.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Update haystack/dataclasses/breakpoints.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Update haystack/dataclasses/breakpoints.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Update haystack/dataclasses/breakpoints.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/core/pipeline/breakpoint.py | 31 +++++++----- haystack/core/pipeline/pipeline.py | 13 ++++- haystack/dataclasses/breakpoints.py | 63 ++++++++++++++++--------- test/core/pipeline/test_breakpoint.py | 68 ++++++++++++++++++++++----- 4 files changed, 127 insertions(+), 48 deletions(-) diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py index 168fe2628e..6699f4fe10 100644 --- a/haystack/core/pipeline/breakpoint.py +++ b/haystack/core/pipeline/breakpoint.py @@ -7,7 +7,7 @@ from dataclasses import replace from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union from networkx import MultiDiGraph @@ -73,7 +73,7 @@ def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnap valid_components = set(graph.nodes.keys()) # Check if the ordered_component_names are valid components in the pipeline - invalid_ordered_components = set(pipeline_state.ordered_component_names) - valid_components + invalid_ordered_components = set(pipeline_snapshot.ordered_component_names) - valid_components if invalid_ordered_components: raise PipelineInvalidPipelineSnapshotError( f"Invalid pipeline snapshot: components {invalid_ordered_components} in 'ordered_component_names' " @@ -81,7 +81,7 @@ def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnap ) # Check if the original_input_data is valid components in the pipeline - serialized_input_data = pipeline_snapshot.pipeline_state.original_input_data["serialized_data"] + serialized_input_data = pipeline_snapshot.original_input_data["serialized_data"] invalid_input_data = set(serialized_input_data.keys()) - valid_components if invalid_input_data: raise PipelineInvalidPipelineSnapshotError( @@ -184,6 +184,9 @@ def _create_pipeline_snapshot( component_visits: Dict[str, int], original_input_data: Optional[Dict[str, Any]] = None, ordered_component_names: Optional[List[str]] = None, + include_outputs_from: Optional[Set[str]] = None, + intermediate_outputs: Optional[Dict[str, Any]] = None, + pipeline_outputs: 
Optional[Dict[str, Any]] = None, ) -> PipelineSnapshot: """ Create a snapshot of the pipeline at the point where the breakpoint was triggered. @@ -193,6 +196,9 @@ def _create_pipeline_snapshot( :param component_visits: The visit count of the component that triggered the breakpoint. :param original_input_data: The original input data. :param ordered_component_names: The ordered component names. + :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results. + :param intermediate_outputs: Dictionary containing outputs from components that are in the include_outputs_from set. + :param pipeline_outputs: Dictionary containing the final outputs of the pipeline up to the breakpoint. """ dt = datetime.now() @@ -201,13 +206,15 @@ def _create_pipeline_snapshot( pipeline_snapshot = PipelineSnapshot( pipeline_state=PipelineState( - original_input_data=_serialize_value_with_schema(transformed_original_input_data), inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs component_visits=component_visits, - ordered_component_names=ordered_component_names or [], + pipeline_outputs=pipeline_outputs or {}, ), timestamp=dt, break_point=break_point, + original_input_data=_serialize_value_with_schema(transformed_original_input_data), + ordered_component_names=ordered_component_names or [], + include_outputs_from=include_outputs_from or set(), ) return pipeline_snapshot @@ -353,12 +360,13 @@ def _check_chat_generator_breakpoint( if parent_snapshot is None: # Create an empty pipeline snapshot if no parent snapshot is provided final_snapshot = PipelineSnapshot( - pipeline_state=PipelineState( - original_input_data={}, inputs={}, component_visits={}, ordered_component_names=[] - ), + pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}), timestamp=agent_snapshot.timestamp, break_point=agent_snapshot.break_point, agent_snapshot=agent_snapshot, + original_input_data={}, + ordered_component_names=[], + include_outputs_from=set(), ) else: final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot) @@ -412,12 +420,13 @@ def _check_tool_invoker_breakpoint( if parent_snapshot is None: # Create an empty pipeline snapshot if no parent snapshot is provided final_snapshot = PipelineSnapshot( - pipeline_state=PipelineState( - original_input_data={}, inputs={}, component_visits={}, ordered_component_names=[] - ), + pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}), timestamp=agent_snapshot.timestamp, break_point=agent_snapshot.break_point, agent_snapshot=agent_snapshot, + original_input_data={}, + ordered_component_names=[], + include_outputs_from=set(), ) else: final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 5cd2b1ad5f..4f4640a9ae 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -212,6 +212,8 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches if include_outputs_from is None: include_outputs_from = set() + pipeline_outputs: Dict[str, Any] = {} + if not pipeline_snapshot: # normalize `data` data = self._prepare_component_input_data(data) @@ -232,16 +234,21 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches # Handle resuming the pipeline from a snapshot component_visits = pipeline_snapshot.pipeline_state.component_visits - ordered_component_names = pipeline_snapshot.pipeline_state.ordered_component_names + ordered_component_names = pipeline_snapshot.ordered_component_names data
= self._prepare_component_input_data(pipeline_snapshot.pipeline_state.inputs) data = _deserialize_value_with_schema(pipeline_snapshot.pipeline_state.inputs) + # restore include_outputs_from from the snapshot when resuming + include_outputs_from = pipeline_snapshot.include_outputs_from + + # also restore pipeline_outputs (the saved intermediate results) from the snapshot when resuming + pipeline_outputs = pipeline_snapshot.pipeline_state.pipeline_outputs + cached_topological_sort = None # We need to access a component's receivers multiple times during a pipeline run. # We store them here for easy access. cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} - pipeline_outputs: Dict[str, Any] = {} with tracing.tracer.trace( "haystack.pipeline.run", tags={ @@ -341,6 +348,8 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches component_visits=component_visits, original_input_data=data, ordered_component_names=ordered_component_names, + include_outputs_from=include_outputs_from, + pipeline_outputs=pipeline_outputs, ) # Scenario 2.1: an AgentBreakpoint is provided to stop the pipeline at a specific component diff --git a/haystack/dataclasses/breakpoints.py b/haystack/dataclasses/breakpoints.py index d67e02e644..4e0f75ad4c 100644 --- a/haystack/dataclasses/breakpoints.py +++ b/haystack/dataclasses/breakpoints.py @@ -2,9 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union @dataclass(frozen=True) @@ -155,31 +155,20 @@ class PipelineState: A dataclass to hold the state of the pipeline at a specific point in time. :param component_visits: A dictionary mapping component names to their visit counts. - :param ordered_component_names: A list of component names in the order they were visited. - :param original_input_data: The original input data provided to the pipeline. :param inputs: The inputs processed by the pipeline at the time of the snapshot. + :param pipeline_outputs: Dictionary containing the final outputs of the pipeline up to the breakpoint. """ - original_input_data: Dict[str, Any] inputs: Dict[str, Any] component_visits: Dict[str, int] - ordered_component_names: List[str] - - def __post_init__(self): - components_in_state = set(self.component_visits.keys()) - components_in_order = set(self.ordered_component_names) - - if components_in_state != components_in_order: - raise ValueError( - f"Inconsistent state: components in PipelineState.component_visits {components_in_state} " - f"do not match components in PipelineState.ordered_component_names {components_in_order}" - ) + pipeline_outputs: Dict[str, Any] def to_dict(self) -> Dict[str, Any]: """ Convert the PipelineState to a dictionary representation. - :return: A dictionary containing the original input data, inputs, component visits, and ordered component names. + :return: A dictionary containing the inputs, component visits, + and pipeline outputs. """ return asdict(self) @@ -188,8 +177,8 @@ def from_dict(cls, data: dict) -> "PipelineState": """ Populate the PipelineState from a dictionary representation. - :param data: A dictionary containing the original input data, inputs, component visits, and ordered component - names. + :param data: A dictionary containing the inputs, component visits, + and pipeline outputs. :return: An instance of PipelineState.
""" return cls(**data) @@ -202,34 +191,61 @@ class PipelineSnapshot: :param pipeline_state: The state of the pipeline at the time of the snapshot. :param break_point: The breakpoint that triggered the snapshot. + :param agent_snapshot: Optional agent snapshot if the breakpoint is an agent breakpoint. :param timestamp: A timestamp indicating when the snapshot was taken. + :param original_input_data: The original input data provided to the pipeline. + :param ordered_component_names: A list of component names in the order they were visited. + :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results. """ + original_input_data: Dict[str, Any] + ordered_component_names: List[str] pipeline_state: PipelineState break_point: Union[AgentBreakpoint, Breakpoint] agent_snapshot: Optional[AgentSnapshot] = None timestamp: Optional[datetime] = None + include_outputs_from: Set[str] = field(default_factory=set) + + def __post_init__(self): + # Validate consistency between component_visits and ordered_component_names + components_in_state = set(self.pipeline_state.component_visits.keys()) + components_in_order = set(self.ordered_component_names) + + if components_in_state != components_in_order: + raise ValueError( + f"Inconsistent state: components in PipelineState.component_visits {components_in_state} " + f"do not match components in PipelineSnapshot.ordered_component_names {components_in_order}" + ) def to_dict(self) -> Dict[str, Any]: """ Convert the PipelineSnapshot to a dictionary representation. - :return: A dictionary containing the pipeline state, timestamp, and breakpoint. + :return: A dictionary containing the pipeline state, timestamp, breakpoint, agent snapshot, original input data, + ordered component names, include_outputs_from, and pipeline outputs. """ - return { + data = { "pipeline_state": self.pipeline_state.to_dict(), "break_point": self.break_point.to_dict(), "agent_snapshot": self.agent_snapshot.to_dict() if self.agent_snapshot else None, "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "original_input_data": self.original_input_data, + "ordered_component_names": self.ordered_component_names, + "include_outputs_from": list(self.include_outputs_from), } + return data @classmethod def from_dict(cls, data: dict) -> "PipelineSnapshot": """ Populate the PipelineSnapshot from a dictionary representation. - :param data: A dictionary containing the pipeline state, timestamp, and breakpoint. + :param data: A dictionary containing the pipeline state, timestamp, breakpoint, agent snapshot, original input + data, ordered component names, include_outputs_from, and pipeline outputs. 
""" + # Convert include_outputs_from list back to set for serialization + include_outputs_from = set(data.get("include_outputs_from", [])) + return cls( pipeline_state=PipelineState.from_dict(data=data["pipeline_state"]), break_point=( @@ -239,4 +255,7 @@ def from_dict(cls, data: dict) -> "PipelineSnapshot": ), agent_snapshot=AgentSnapshot.from_dict(data["agent_snapshot"]) if data.get("agent_snapshot") else None, timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + original_input_data=data.get("original_input_data", {}), + ordered_component_names=data.get("ordered_component_names", []), + include_outputs_from=include_outputs_from, ) diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py index 53ad8f150d..7e9a53f7ad 100644 --- a/test/core/pipeline/test_breakpoint.py +++ b/test/core/pipeline/test_breakpoint.py @@ -6,8 +6,11 @@ import pytest +from haystack import component +from haystack.core.errors import BreakpointException +from haystack.core.pipeline import Pipeline from haystack.core.pipeline.breakpoint import _transform_json_structure, load_pipeline_snapshot -from haystack.dataclasses.breakpoints import PipelineSnapshot +from haystack.dataclasses.breakpoints import Breakpoint, PipelineSnapshot def test_transform_json_structure_unwraps_sender_value(): @@ -37,12 +40,10 @@ def test_transform_json_structure_handles_nested_structures(): def test_load_pipeline_snapshot_loads_valid_snapshot(tmp_path): pipeline_snapshot = { "break_point": {"component_name": "comp1", "visit_count": 0}, - "pipeline_state": { - "original_input_data": {}, - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp2"], - }, + "pipeline_state": {"inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, "pipeline_outputs": {}}, + "original_input_data": {}, + "ordered_component_names": ["comp1", "comp2"], + "include_outputs_from": ["comp1", "comp2"], } pipeline_snapshot_file = tmp_path / "state.json" with open(pipeline_snapshot_file, "w") as f: @@ -55,12 +56,10 @@ def test_load_pipeline_snapshot_loads_valid_snapshot(tmp_path): def test_load_state_handles_invalid_state(tmp_path): pipeline_snapshot = { "break_point": {"component_name": "comp1", "visit_count": 0}, - "pipeline_state": { - "original_input_data": {}, - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp3"], # inconsistent with component_visits - }, + "pipeline_state": {"inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, "pipeline_outputs": {}}, + "original_input_data": {}, + "include_outputs_from": ["comp1", "comp2"], + "ordered_component_names": ["comp1", "comp3"], # inconsistent with component_visits } pipeline_snapshot_file = tmp_path / "invalid_pipeline_snapshot.json" @@ -69,3 +68,46 @@ def test_load_state_handles_invalid_state(tmp_path): with pytest.raises(ValueError, match="Invalid pipeline snapshot from"): load_pipeline_snapshot(pipeline_snapshot_file) + + +def test_breakpoint_saves_intermediate_outputs(tmp_path): + @component + class SimpleComponent: + @component.output_types(result=str) + def run(self, input_value: str) -> dict[str, str]: + return {"result": f"processed_{input_value}"} + + pipeline = Pipeline() + comp1 = SimpleComponent() + comp2 = SimpleComponent() + pipeline.add_component("comp1", comp1) + pipeline.add_component("comp2", comp2) + pipeline.connect("comp1", "comp2") + + # breakpoint on comp2 + break_point = Breakpoint(component_name="comp2", 
visit_count=0, snapshot_file_path=str(tmp_path)) + + try: + # run with include_outputs_from to capture intermediate outputs + pipeline.run(data={"comp1": {"input_value": "test"}}, include_outputs_from={"comp1"}, break_point=break_point) + except BreakpointException as e: + # breakpoint should be triggered + assert e.component == "comp2" + + # verify snapshot file contains the intermediate outputs + snapshot_files = list(tmp_path.glob("comp2_*.json")) + assert len(snapshot_files) == 1, f"Expected exactly one snapshot file, found {len(snapshot_files)}" + + snapshot_file = snapshot_files[0] + loaded_snapshot = load_pipeline_snapshot(snapshot_file) + + # verify the snapshot contains the intermediate outputs from comp1 + assert "comp1" in loaded_snapshot.pipeline_state.pipeline_outputs + assert loaded_snapshot.pipeline_state.pipeline_outputs["comp1"]["result"] == "processed_test" + + # verify the whole pipeline state contains the expected data + assert loaded_snapshot.pipeline_state.component_visits["comp1"] == 1 + assert loaded_snapshot.pipeline_state.component_visits["comp2"] == 0 + assert "comp1" in loaded_snapshot.include_outputs_from + assert loaded_snapshot.break_point.component_name == "comp2" + assert loaded_snapshot.break_point.visit_count == 0 From bc32c230304be629ca5ad3d157d25dc311a96151 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 23 Jul 2025 14:35:48 +0200 Subject: [PATCH 13/21] linting --- test/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/conftest.py b/test/conftest.py index 4457714fe8..eeef71d923 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -112,6 +112,7 @@ def load_and_resume_pipeline_snapshot(pipeline, output_directory: Path, componen msg = f"No files found for {component_name} in {output_directory}." raise ValueError(msg) + @pytest.fixture() def base64_image_string(): return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" From ba1fdce5d208354eb4cc2036f91ee7ca02a7a0db Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 23 Jul 2025 14:57:34 +0200 Subject: [PATCH 14/21] cleaning up --- haystack/core/pipeline/breakpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py index 6699f4fe10..4fe371491e 100644 --- a/haystack/core/pipeline/breakpoint.py +++ b/haystack/core/pipeline/breakpoint.py @@ -185,7 +185,6 @@ def _create_pipeline_snapshot( original_input_data: Optional[Dict[str, Any]] = None, ordered_component_names: Optional[List[str]] = None, include_outputs_from: Optional[Set[str]] = None, - intermediate_outputs: Optional[Dict[str, Any]] = None, pipeline_outputs: Optional[Dict[str, Any]] = None, ) -> PipelineSnapshot: """ From be0b7c2ef54c01961249ec434d620b6c9784a678 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Wed, 23 Jul 2025 15:46:42 +0200 Subject: [PATCH 15/21] avoiding creating PipelineSnapshot for every component run --- haystack/core/pipeline/pipeline.py | 45 +++++++++++++++++++----------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 4f4640a9ae..b0e5808361 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -339,33 +339,46 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches # Scenario 2: A breakpoint is provided to stop the pipeline at a specific component if break_point: - # Create a PipelineSnapshot to capture the current state of the pipeline - pipeline_snapshot_inputs_serialised = deepcopy(inputs) - pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs) - new_pipeline_snapshot = _create_pipeline_snapshot( - inputs=pipeline_snapshot_inputs_serialised, - break_point=break_point, - component_visits=component_visits, - original_input_data=data, - ordered_component_names=ordered_component_names, - include_outputs_from=include_outputs_from, - pipeline_outputs=pipeline_outputs, - ) + should_trigger_breakpoint = False + should_create_snapshot = False # Scenario 2.1: an AgentBreakpoint is provided to stop the pipeline at a specific component if isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name: - # Add the break_point and pipeline_snapshot to the agent's component inputs + should_create_snapshot = True component_inputs["break_point"] = break_point - component_inputs["parent_snapshot"] = new_pipeline_snapshot # Scenario 2.2: a regular breakpoint is provided to stop the pipeline at a specific component and # visit count - if ( + elif ( isinstance(break_point, Breakpoint) and break_point.component_name == component_name and break_point.visit_count == component_visits[component_name] ): - _trigger_break_point(pipeline_snapshot=new_pipeline_snapshot, pipeline_outputs=pipeline_outputs) + should_trigger_breakpoint = True + should_create_snapshot = True + + if should_create_snapshot: + pipeline_snapshot_inputs_serialised = deepcopy(inputs) + pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs) + new_pipeline_snapshot = _create_pipeline_snapshot( + inputs=pipeline_snapshot_inputs_serialised, + break_point=break_point, + component_visits=component_visits, + original_input_data=data, + ordered_component_names=ordered_component_names, + include_outputs_from=include_outputs_from, + pipeline_outputs=pipeline_outputs, + ) + + # add the parent_snapshot to agent inputs if needed + if isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name: + component_inputs["parent_snapshot"] = new_pipeline_snapshot + + # trigger the breakpoint if needed + if should_trigger_breakpoint: + _trigger_break_point( + pipeline_snapshot=new_pipeline_snapshot, pipeline_outputs=pipeline_outputs + ) component_outputs = self._run_component( component_name=component_name, From 2e25b758069219c8c9b6e4c3f96400e9bea3195c Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Wed, 23 Jul 2025 16:16:51 +0200 Subject: [PATCH 16/21] removing unecessary code --- haystack/core/pipeline/pipeline.py | 1 - haystack/utils/base_serialization.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index b0e5808361..221b496ffe 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -235,7 +235,6 @@ def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches # Handle resuming the pipeline from a snapshot component_visits = pipeline_snapshot.pipeline_state.component_visits ordered_component_names = pipeline_snapshot.ordered_component_names - data = self._prepare_component_input_data(pipeline_snapshot.pipeline_state.inputs) data = _deserialize_value_with_schema(pipeline_snapshot.pipeline_state.inputs) # include_outputs_from from the snapshot when resuming diff --git a/haystack/utils/base_serialization.py b/haystack/utils/base_serialization.py index 3303fc61f3..2f98b30471 100644 --- a/haystack/utils/base_serialization.py +++ b/haystack/utils/base_serialization.py @@ -208,7 +208,7 @@ def _deserialize_value_with_schema(serialized: Dict[str, Any]) -> Any: # pylint schema_type = schema.get("type") if not schema_type: - # for backward comaptability till Haystack 2.16 we use legacy implementation + # for backward compatibility till Haystack 2.16 we use legacy implementation raise DeserializationError( "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized " "State object created with a version of Haystack older than 2.15.0. " From 1c475bf7c2865ae09cf7e39dde90977d8a7cdf9f Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 24 Jul 2025 08:56:27 +0200 Subject: [PATCH 17/21] Update checks in Agent to not unecessarily create AgentSnapshot when not needed. 
--- haystack/components/agents/agent.py | 24 ++++++++++++++++++++---- haystack/core/pipeline/breakpoint.py | 7 +------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index d8c4fd9f1e..5abbaea110 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -333,7 +333,11 @@ def run( # noqa: PLR0915 while counter < self.max_agent_steps: # check for breakpoint before ChatGenerator - if break_point and break_point.break_point.component_name == "chat_generator": + if ( + break_point + and break_point.break_point.component_name == "chat_generator" + and component_visits["chat_generator"] == break_point.break_point.visit_count + ): agent_snapshot = _create_agent_snapshot( component_visits=component_visits, agent_breakpoint=break_point, @@ -367,7 +371,11 @@ def run( # noqa: PLR0915 break # check for breakpoint before ToolInvoker - if break_point and break_point.break_point.component_name == "tool_invoker": + if ( + break_point + and break_point.break_point.component_name == "tool_invoker" + and break_point.break_point.visit_count == component_visits["tool_invoker"] + ): agent_snapshot = _create_agent_snapshot( component_visits=component_visits, agent_breakpoint=break_point, @@ -517,7 +525,11 @@ async def run_async( # noqa: PLR0915 while counter < self.max_agent_steps: # check for breakpoint before ChatGenerator - if break_point and break_point.break_point.component_name == "chat_generator": + if ( + break_point + and break_point.break_point.component_name == "chat_generator" + and component_visits["chat_generator"] == break_point.break_point.visit_count + ): agent_snapshot = _create_agent_snapshot( component_visits=component_visits, agent_breakpoint=break_point, @@ -552,7 +564,11 @@ async def run_async( # noqa: PLR0915 break # Check for breakpoint before ToolInvoker - if break_point and break_point.break_point.component_name == "tool_invoker": + if ( + break_point + and break_point.break_point.component_name == "tool_invoker" + and break_point.break_point.visit_count == component_visits["tool_invoker"] + ): agent_snapshot = _create_agent_snapshot( component_visits=component_visits, agent_breakpoint=break_point, diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py index 4fe371491e..6ec1733846 100644 --- a/haystack/core/pipeline/breakpoint.py +++ b/haystack/core/pipeline/breakpoint.py @@ -353,8 +353,6 @@ def _check_chat_generator_breakpoint( """ break_point = agent_snapshot.break_point.break_point - if agent_snapshot.component_visits[break_point.component_name] != break_point.visit_count: - return if parent_snapshot is None: # Create an empty pipeline snapshot if no parent snapshot is provided @@ -373,7 +371,7 @@ def _check_chat_generator_breakpoint( msg = ( f"Breaking at {break_point.component_name} visit count " - "{agent_snapshot.component_visits[break_point.component_name]}" + f"{agent_snapshot.component_visits[break_point.component_name]}" ) logger.info(msg) raise BreakpointException( @@ -399,9 +397,6 @@ def _check_tool_invoker_breakpoint( return tool_breakpoint = agent_snapshot.break_point.break_point - # Check if the visit count matches - if agent_snapshot.component_visits[tool_breakpoint.component_name] != tool_breakpoint.visit_count: - return # Check if we should break for this specific tool or all tools if tool_breakpoint.tool_name is None: From f92e19c1622cd932530b13b4201a5b1159b567ed Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Thu, 24 Jul 2025 10:20:11 +0200 Subject: [PATCH 18/21] Update haystack/components/agents/agent.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/components/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 5abbaea110..7219638505 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -269,7 +269,7 @@ def run( # noqa: PLR0915 """ # kwargs can contain the key parent_snapshot. - # We pop it here to avoid passing it into State. We explicitly handle it pass it on if a break point is + # We pop it here to avoid passing it into State. We explicitly pass it on if a break point is # triggered. parent_snapshot = kwargs.pop("parent_snapshot", None) From 3b26f656ce1854bd0b1f280c318a25bba9926d18 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 24 Jul 2025 10:20:23 +0200 Subject: [PATCH 19/21] Update haystack/components/agents/agent.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/components/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 7219638505..596fdfe78d 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -256,7 +256,7 @@ def run( # noqa: PLR0915 The same callback can be configured to emit tool results when a tool is called. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". - :param snapshot: A dictionary containing the state of a previously saved agent execution. + :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains the relevant information to restart the Agent execution from where it left off. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: From 0e45966e8a00ab416d42c695939d8922d5abc1fe Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Thu, 24 Jul 2025 10:23:20 +0200 Subject: [PATCH 20/21] cleaning up tests --- test/core/pipeline/test_pipeline.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 7f5962d42d..20514b7325 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -123,13 +123,3 @@ def run(self): component_visits={"erroring_component": 0}, ) assert "Component name: 'erroring_component'" in str(exc_info.value) - - def test_run(self): - joiner_1 = BranchJoiner(type_=str) - joiner_2 = BranchJoiner(type_=str) - pp = Pipeline() - pp.add_component("joiner_1", joiner_1) - pp.add_component("joiner_2", joiner_2) - pp.connect("joiner_1", "joiner_2") - - _ = pp.run({"value": "test_value"}) From 6fb202274e422849448e055f9376fc0aa2e7ac86 Mon Sep 17 00:00:00 2001 From: "David S. 
Batista" Date: Thu, 24 Jul 2025 10:24:05 +0200 Subject: [PATCH 21/21] linting --- haystack/components/agents/agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 596fdfe78d..595a931035 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -256,7 +256,8 @@ def run( # noqa: PLR0915 The same callback can be configured to emit tool results when a tool is called. :param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint for "tool_invoker". - :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains the relevant information to restart the Agent execution from where it left off. + :param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains + the relevant information to restart the Agent execution from where it left off. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: