diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 5df6549159..595a931035 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -5,16 +5,25 @@ import inspect from typing import Any, Dict, List, Optional, Union -from haystack import component, default_from_dict, default_to_dict, logging, tracing +from haystack import logging, tracing from haystack.components.generators.chat.types import ChatGenerator from haystack.components.tools import ToolInvoker +from haystack.core.component.component import component from haystack.core.pipeline.async_pipeline import AsyncPipeline +from haystack.core.pipeline.breakpoint import ( + _check_chat_generator_breakpoint, + _check_tool_invoker_breakpoint, + _create_agent_snapshot, + _validate_tool_breakpoint_is_valid, +) from haystack.core.pipeline.pipeline import Pipeline from haystack.core.pipeline.utils import _deepcopy_with_exceptions -from haystack.core.serialization import component_to_dict +from haystack.core.serialization import component_to_dict, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses.breakpoints import AgentBreakpoint, AgentSnapshot, ToolBreakpoint from haystack.dataclasses.streaming_chunk import StreamingCallbackT, select_streaming_callback from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset +from haystack.utils import _deserialize_value_with_schema from haystack.utils.callable_serialization import deserialize_callable, serialize_callable from haystack.utils.deserialization import deserialize_chatgenerator_inplace @@ -229,8 +238,14 @@ def _create_agent_span(self) -> Any: }, ) - def run( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + def run( # noqa: PLR0915 + self, + messages: List[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + *, + break_point: Optional[AgentBreakpoint] = None, + snapshot: Optional[AgentSnapshot] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Process messages and execute tools until an exit condition is met. @@ -239,6 +254,10 @@ def run( If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param break_point: An AgentBreakpoint; can be a Breakpoint for the "chat_generator" or a ToolBreakpoint + for "tool_invoker". + :param snapshot: An AgentSnapshot containing a snapshot of a previously saved agent execution. The snapshot + contains the relevant information to restart the Agent execution from where it left off. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -247,22 +266,60 @@ def run( - "messages": List of all messages exchanged during the agent's run. - "last_message": The last message exchanged during the agent's run. - Any additional keys defined in the `state_schema`. :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. + :raises BreakpointException: If an agent breakpoint is triggered. + """ + # kwargs can contain the key parent_snapshot. + # We pop it here to avoid passing it into State. We explicitly pass it on if a break point is + # triggered.
+ parent_snapshot = kwargs.pop("parent_snapshot", None) + if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.") - if self.system_prompt is not None: - messages = [ChatMessage.from_system(self.system_prompt)] + messages - - if all(m.is_from(ChatRole.SYSTEM) for m in messages): - logger.warning( - "All messages provided to the Agent component are system messages. This is not recommended as the " - "Agent will not perform any actions specific to user input. Consider adding user messages to the input." + if break_point and snapshot: + raise ValueError( + "break_point and snapshot cannot be provided at the same time. The agent run will be aborted." ) - state = State(schema=self.state_schema, data=kwargs) + # validate breakpoints + if break_point and isinstance(break_point.break_point, ToolBreakpoint): + _validate_tool_breakpoint_is_valid(agent_breakpoint=break_point, tools=self.tools) + + # Handle agent snapshot if provided + if snapshot: + component_visits = snapshot.component_visits + current_inputs = { + "chat_generator": _deserialize_value_with_schema(snapshot.component_inputs["chat_generator"]), + "tool_invoker": _deserialize_value_with_schema(snapshot.component_inputs["tool_invoker"]), + } + state_data = current_inputs["tool_invoker"]["state"].data + if isinstance(snapshot.break_point.break_point, ToolBreakpoint): + # If the break point is a ToolBreakpoint, we need to get the messages from the tool invoker inputs + messages = current_inputs["tool_invoker"]["messages"] + # Needed to know if we should start with the ToolInvoker or ChatGenerator + skip_chat_generator = True + else: + messages = current_inputs["chat_generator"]["messages"] + skip_chat_generator = False + # We also load the streaming_callback from the snapshot if it exists + streaming_callback = current_inputs["chat_generator"].get("streaming_callback", None) + else: + skip_chat_generator = False + if self.system_prompt is not None: + messages = [ChatMessage.from_system(self.system_prompt)] + messages + + if all(m.is_from(ChatRole.SYSTEM) for m in messages): + logger.warning( + "All messages provided to the Agent component are system messages. This is not recommended as the " + "Agent will not perform any actions specific to user input. Consider adding user messages to the " + "input." 
+ ) + component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) + state_data = kwargs + + state = State(schema=self.state_schema, data=state_data) state.set("messages", messages) - component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False @@ -274,23 +331,68 @@ def run( _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}), ) counter = 0 + while counter < self.max_agent_steps: + # check for breakpoint before ChatGenerator + if ( + break_point + and break_point.break_point.component_name == "chat_generator" + and component_visits["chat_generator"] == break_point.break_point.visit_count + ): + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages, **generator_inputs}, + "tool_invoker": {"messages": [], "state": state, "streaming_callback": streaming_callback}, + }, + ) + _check_chat_generator_breakpoint(agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot) + # 1. Call the ChatGenerator - result = Pipeline._run_component( - component_name="chat_generator", - component={"instance": self.chat_generator}, - inputs={"messages": messages, **generator_inputs}, - component_visits=component_visits, - parent_span=span, - ) - llm_messages = result["replies"] - state.set("messages", llm_messages) + # We skip the chat generator when restarting from a snapshot where we restart at the ToolInvoker. + if skip_chat_generator: + llm_messages = state.get("messages", [])[-1:] + # We set it to False to ensure that the next iteration will call the chat generator again + skip_chat_generator = False + else: + result = Pipeline._run_component( + component_name="chat_generator", + component={"instance": self.chat_generator}, + inputs={"messages": messages, **generator_inputs}, + component_visits=component_visits, + parent_span=span, + ) + llm_messages = result["replies"] + state.set("messages", llm_messages) # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: counter += 1 break + # check for breakpoint before ToolInvoker + if ( + break_point + and break_point.break_point.component_name == "tool_invoker" + and break_point.break_point.visit_count == component_visits["tool_invoker"] + ): + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages[:-1], **generator_inputs}, + "tool_invoker": { + "messages": llm_messages, + "state": state, + "streaming_callback": streaming_callback, + }, + }, + ) + _check_tool_invoker_breakpoint( + llm_messages=llm_messages, agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot + ) + # 3. 
Call the ToolInvoker # We only send the messages from the LLM to the tool invoker tool_invoker_result = Pipeline._run_component( @@ -327,8 +429,14 @@ def run( result.update({"last_message": all_messages[-1]}) return result - async def run_async( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + async def run_async( # noqa: PLR0915 + self, + messages: List[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + *, + break_point: Optional[AgentBreakpoint] = None, + snapshot: Optional[AgentSnapshot] = None, + **kwargs: Any, ) -> Dict[str, Any]: """ Asynchronously process messages and execute tools until the exit condition is met. @@ -339,8 +447,10 @@ async def run_async( :param messages: List of chat messages to process :param streaming_callback: An asynchronous callback that will be invoked when a response - is streamed from the LLM. The same callback can be configured to emit tool results - when a tool is called. + is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param break_point: An AgentBreakpoint; can be a Breakpoint for the "chat_generator" or a ToolBreakpoint + for "tool_invoker". + :param snapshot: An AgentSnapshot containing the state of a previously saved agent execution. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -349,22 +459,59 @@ async def run_async( - "messages": List of all messages exchanged during the agent's run. - "last_message": The last message exchanged during the agent's run. - Any additional keys defined in the `state_schema`. :raises RuntimeError: If the Agent component wasn't warmed up before calling `run_async()`. + :raises BreakpointException: If an agent breakpoint is triggered. """ + # kwargs can contain the key parent_snapshot. + # We pop it here to avoid passing it into State. We explicitly pass it on if a break point is + # triggered. + parent_snapshot = kwargs.pop("parent_snapshot", None) + if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"): raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.") - if self.system_prompt is not None: - messages = [ChatMessage.from_system(self.system_prompt)] + messages - - if all(m.is_from(ChatRole.SYSTEM) for m in messages): - logger.warning( - "All messages provided to the Agent component are system messages. This is not recommended as the " - "Agent will not perform any actions specific to user input. Consider adding user messages to the input." + if break_point and snapshot: + raise ValueError( + "break_point and snapshot cannot be provided at the same time. The agent run will be aborted."
) - state = State(schema=self.state_schema, data=kwargs) + # validate breakpoints + if break_point and isinstance(break_point.break_point, ToolBreakpoint): + _validate_tool_breakpoint_is_valid(agent_breakpoint=break_point, tools=self.tools) + + # Handle agent snapshot if provided + if snapshot: + component_visits = snapshot.component_visits + current_inputs = { + "chat_generator": _deserialize_value_with_schema(snapshot.component_inputs["chat_generator"]), + "tool_invoker": _deserialize_value_with_schema(snapshot.component_inputs["tool_invoker"]), + } + state_data = current_inputs["tool_invoker"]["state"].data + if isinstance(snapshot.break_point.break_point, ToolBreakpoint): + # If the break point is a ToolBreakpoint, we need to get the messages from the tool invoker inputs + messages = current_inputs["tool_invoker"]["messages"] + # Needed to know if we should start with the ToolInvoker or ChatGenerator + skip_chat_generator = True + else: + messages = current_inputs["chat_generator"]["messages"] + skip_chat_generator = False + # We also load the streaming_callback from the snapshot if it exists + streaming_callback = current_inputs["chat_generator"].get("streaming_callback", None) + else: + skip_chat_generator = False + if self.system_prompt is not None: + messages = [ChatMessage.from_system(self.system_prompt)] + messages + + if all(m.is_from(ChatRole.SYSTEM) for m in messages): + logger.warning( + "All messages provided to the Agent component are system messages. This is not recommended as the " + "Agent will not perform any actions specific to user input. Consider adding user messages to the " + "input." + ) + component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) + state_data = kwargs + + state = State(schema=self.state_schema, data=state_data) state.set("messages", messages) - component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0) streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True @@ -376,23 +523,69 @@ async def run_async( _deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}), ) counter = 0 + while counter < self.max_agent_steps: + # check for breakpoint before ChatGenerator + if ( + break_point + and break_point.break_point.component_name == "chat_generator" + and component_visits["chat_generator"] == break_point.break_point.visit_count + ): + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages, **generator_inputs}, + "tool_invoker": {"messages": [], "state": state, "streaming_callback": streaming_callback}, + }, + ) + _check_chat_generator_breakpoint(agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot) + # 1. Call the ChatGenerator - result = await AsyncPipeline._run_component_async( - component_name="chat_generator", - component={"instance": self.chat_generator}, - component_inputs={"messages": messages, **generator_inputs}, - component_visits=component_visits, - parent_span=span, - ) - llm_messages = result["replies"] - state.set("messages", llm_messages) + # If skip_chat_generator is True, we skip the chat generator and use the messages from the state + # This is useful when the agent is resumed from a snapshot where the chat generator already ran. 
+ if skip_chat_generator: + llm_messages = state.get("messages", [])[-1:] # Get the last message from the state + # We set it to False to ensure that the next iteration will call the chat generator again + skip_chat_generator = False + else: + result = await AsyncPipeline._run_component_async( + component_name="chat_generator", + component={"instance": self.chat_generator}, + component_inputs={"messages": messages, **generator_inputs}, + component_visits=component_visits, + parent_span=span, + ) + llm_messages = result["replies"] + state.set("messages", llm_messages) # 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None: counter += 1 break + # Check for breakpoint before ToolInvoker + if ( + break_point + and break_point.break_point.component_name == "tool_invoker" + and break_point.break_point.visit_count == component_visits["tool_invoker"] + ): + agent_snapshot = _create_agent_snapshot( + component_visits=component_visits, + agent_breakpoint=break_point, + component_inputs={ + "chat_generator": {"messages": messages[:-1], **generator_inputs}, + "tool_invoker": { + "messages": llm_messages, + "state": state, + "streaming_callback": streaming_callback, + }, + }, + ) + _check_tool_invoker_breakpoint( + llm_messages=llm_messages, agent_snapshot=agent_snapshot, parent_snapshot=parent_snapshot + ) + # 3. Call the ToolInvoker # We only send the messages from the LLM to the tool invoker tool_invoker_result = await AsyncPipeline._run_component_async( diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 679460496f..af425c50f9 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -9,9 +9,11 @@ from functools import partial from typing import Any, Dict, List, Optional, Set, Union -from haystack import component, default_from_dict, default_to_dict, logging +from haystack import logging from haystack.components.agents import State +from haystack.core.component.component import component from haystack.core.component.sockets import Sockets +from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, ToolCall from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback from haystack.tools import ( diff --git a/haystack/core/errors.py b/haystack/core/errors.py index 04c4ccc864..d3ddbc46c6 100644 --- a/haystack/core/errors.py +++ b/haystack/core/errors.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional, Type +from typing import Any, Dict, Optional, Type class PipelineError(Exception): @@ -89,3 +89,30 @@ class DeserializationError(Exception): pass class SerializationError(Exception): pass + + +class BreakpointException(Exception): + """ + Exception raised when a pipeline breakpoint is triggered. + """ + + def __init__( + self, + message: str, + component: Optional[str] = None, + inputs: Optional[Dict[str, Any]] = None, + results: Optional[Dict[str, Any]] = None, + ): + super().__init__(message) + self.component = component + self.inputs = inputs + self.results = results + + +class PipelineInvalidPipelineSnapshotError(Exception): + """ + Exception raised when a pipeline is resumed from an invalid snapshot.
+ """ + + def __init__(self, message: str): + super().__init__(message) diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 68908144d2..c0d318e7e1 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -1079,7 +1079,9 @@ def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Di return inputs @staticmethod - def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]: + def _consume_component_inputs( + component_name: str, component: Dict, inputs: Dict, is_resume: bool = False + ) -> Dict[str, Any]: """ Extracts the inputs needed to run for the component and removes them from the global inputs state. @@ -1094,6 +1096,11 @@ def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict for socket_name, socket in component["input_sockets"].items(): socket_inputs = component_inputs.get(socket_name, []) socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED] + + # if we are resuming a component, the inputs are already consumed, so we just return the first input + if is_resume: + consumed_inputs[socket_name] = socket_inputs[0] + continue if socket_inputs: if not socket.is_variadic: # We only care about the first input provided to the socket. diff --git a/haystack/core/pipeline/breakpoint.py b/haystack/core/pipeline/breakpoint.py new file mode 100644 index 0000000000..6ec1733846 --- /dev/null +++ b/haystack/core/pipeline/breakpoint.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from copy import deepcopy +from dataclasses import replace +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from networkx import MultiDiGraph + +from haystack import logging +from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import ( + AgentBreakpoint, + AgentSnapshot, + Breakpoint, + PipelineSnapshot, + PipelineState, + ToolBreakpoint, +) +from haystack.tools import Tool, Toolset +from haystack.utils.base_serialization import _serialize_value_with_schema + +logger = logging.getLogger(__name__) + + +def _validate_break_point_against_pipeline( + break_point: Union[Breakpoint, AgentBreakpoint], graph: MultiDiGraph +) -> None: + """ + Validates the breakpoints passed to the pipeline. + + Makes sure the breakpoint contains a valid components registered in the pipeline. 
+ + :param break_point: A breakpoint to validate; can be a Breakpoint or an AgentBreakpoint. + """ + + # all Breakpoints must refer to a valid component in the pipeline + if isinstance(break_point, Breakpoint) and break_point.component_name not in graph.nodes: + raise ValueError(f"break_point {break_point} is not a registered component in the pipeline") + + if isinstance(break_point, AgentBreakpoint): + breakpoint_agent_component = graph.nodes.get(break_point.agent_name) + if not breakpoint_agent_component: + raise ValueError(f"break_point {break_point} is not a registered Agent component in the pipeline") + + if isinstance(break_point.break_point, ToolBreakpoint): + instance = breakpoint_agent_component["instance"] + for tool in instance.tools: + if break_point.break_point.tool_name == tool.name: + break + else: + raise ValueError( + f"break_point {break_point.break_point} is not a registered tool in the Agent component" + ) + + +def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnapshot, graph: MultiDiGraph) -> None: + """ + Validates that the pipeline_snapshot contains a valid configuration for the current pipeline. + + Raises a PipelineInvalidPipelineSnapshotError if any component in pipeline_snapshot is not part of the + target pipeline. + + :param pipeline_snapshot: The saved state to validate. + """ + + pipeline_state = pipeline_snapshot.pipeline_state + valid_components = set(graph.nodes.keys()) + + # Check if the ordered_component_names are valid components in the pipeline + invalid_ordered_components = set(pipeline_snapshot.ordered_component_names) - valid_components + if invalid_ordered_components: + raise PipelineInvalidPipelineSnapshotError( + f"Invalid pipeline snapshot: components {invalid_ordered_components} in 'ordered_component_names' " + f"are not part of the current pipeline." + ) + + # Check if the components in original_input_data are valid components in the pipeline + serialized_input_data = pipeline_snapshot.original_input_data["serialized_data"] + invalid_input_data = set(serialized_input_data.keys()) - valid_components + if invalid_input_data: + raise PipelineInvalidPipelineSnapshotError( + f"Invalid pipeline snapshot: components {invalid_input_data} in 'input_data' " + f"are not part of the current pipeline." + ) + + # Validate 'component_visits' + invalid_component_visits = set(pipeline_state.component_visits.keys()) - valid_components + if invalid_component_visits: + raise PipelineInvalidPipelineSnapshotError( + f"Invalid pipeline snapshot: components {invalid_component_visits} in 'component_visits' " + f"are not part of the current pipeline." + ) + + if isinstance(pipeline_snapshot.break_point, AgentBreakpoint): + component_name = pipeline_snapshot.break_point.agent_name + else: + component_name = pipeline_snapshot.break_point.component_name + + visit_count = pipeline_snapshot.pipeline_state.component_visits[component_name] + + logger.info( + "Resuming pipeline from {component} with visit count {visits}", component=component_name, visits=visit_count + ) + + +def load_pipeline_snapshot(file_path: Union[str, Path]) -> PipelineSnapshot: + """ + Load a saved pipeline snapshot. + + :param file_path: Path to the pipeline_snapshot file. + :returns: + The loaded PipelineSnapshot.
+ """ + + file_path = Path(file_path) + + try: + with open(file_path, "r", encoding="utf-8") as f: + pipeline_snapshot_dict = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"File not found: {file_path}") + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos) + except IOError as e: + raise IOError(f"Error reading {file_path}: {str(e)}") + + try: + pipeline_snapshot = PipelineSnapshot.from_dict(pipeline_snapshot_dict) + except ValueError as e: + raise ValueError(f"Invalid pipeline snapshot from {file_path}: {str(e)}") + + logger.info(f"Successfully loaded the pipeline snapshot from: {file_path}") + return pipeline_snapshot + + +def _save_pipeline_snapshot_to_file( + *, pipeline_snapshot: PipelineSnapshot, snapshot_file_path: Union[str, Path], dt: datetime +) -> None: + """ + Save the pipeline snapshot dictionary to a JSON file. + + :param pipeline_snapshot: The pipeline snapshot to save. + :param snapshot_file_path: The path where to save the file. + :param dt: The datetime object for timestamping. + :raises: + ValueError: If the snapshot_file_path is not a string or a Path object. + Exception: If saving the JSON snapshot fails. + """ + snapshot_file_path = Path(snapshot_file_path) if isinstance(snapshot_file_path, str) else snapshot_file_path + if not isinstance(snapshot_file_path, Path): + raise ValueError("Debug path must be a string or a Path object.") + + snapshot_file_path.mkdir(exist_ok=True) + + # Generate filename + # We check if the agent_name is provided to differentiate between agent and non-agent breakpoints + if isinstance(pipeline_snapshot.break_point, AgentBreakpoint): + agent_name = pipeline_snapshot.break_point.agent_name + component_name = pipeline_snapshot.break_point.break_point.component_name + file_name = f"{agent_name}_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + else: + component_name = pipeline_snapshot.break_point.component_name + file_name = f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json" + + try: + with open(snapshot_file_path / file_name, "w") as f_out: + json.dump(pipeline_snapshot.to_dict(), f_out, indent=2) + logger.info(f"Pipeline snapshot saved at: {file_name}") + except Exception as e: + logger.error(f"Failed to save pipeline snapshot: {str(e)}") + raise + + +def _create_pipeline_snapshot( + *, + inputs: Dict[str, Any], + break_point: Union[AgentBreakpoint, Breakpoint], + component_visits: Dict[str, int], + original_input_data: Optional[Dict[str, Any]] = None, + ordered_component_names: Optional[List[str]] = None, + include_outputs_from: Optional[Set[str]] = None, + pipeline_outputs: Optional[Dict[str, Any]] = None, +) -> PipelineSnapshot: + """ + Create a snapshot of the pipeline at the point where the breakpoint was triggered. + + :param inputs: The current pipeline snapshot inputs. + :param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint. + :param component_visits: The visit count of the component that triggered the breakpoint. + :param original_input_data: The original input data. + :param ordered_component_names: The ordered component names. + :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results. + :param intermediate_outputs: Dictionary containing outputs from components that are in the include_outputs_from set. 
+ """ + dt = datetime.now() + + transformed_original_input_data = _transform_json_structure(original_input_data) + transformed_inputs = _transform_json_structure(inputs) + + pipeline_snapshot = PipelineSnapshot( + pipeline_state=PipelineState( + inputs=_serialize_value_with_schema(transformed_inputs), # current pipeline inputs + component_visits=component_visits, + pipeline_outputs=pipeline_outputs or {}, + ), + timestamp=dt, + break_point=break_point, + original_input_data=_serialize_value_with_schema(transformed_original_input_data), + ordered_component_names=ordered_component_names or [], + include_outputs_from=include_outputs_from or set(), + ) + return pipeline_snapshot + + +def _save_pipeline_snapshot(pipeline_snapshot: PipelineSnapshot) -> PipelineSnapshot: + """ + Save the pipeline snapshot to a file. + + :param pipeline_snapshot: The pipeline snapshot to save. + + :returns: + The dictionary containing the snapshot of the pipeline containing the following keys: + - input_data: The original input data passed to the pipeline. + - timestamp: The timestamp of the breakpoint. + - pipeline_breakpoint: The component name and visit count that triggered the breakpoint. + - pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys: + - inputs: The current state of inputs for pipeline components. + - component_visits: The visit count of the components when the breakpoint was triggered. + - ordered_component_names: The order of components in the pipeline. + """ + break_point = pipeline_snapshot.break_point + if isinstance(break_point, AgentBreakpoint): + snapshot_file_path = break_point.break_point.snapshot_file_path + else: + snapshot_file_path = break_point.snapshot_file_path + + if snapshot_file_path is not None: + dt = pipeline_snapshot.timestamp or datetime.now() + _save_pipeline_snapshot_to_file( + pipeline_snapshot=pipeline_snapshot, snapshot_file_path=snapshot_file_path, dt=dt + ) + + return pipeline_snapshot + + +def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: + """ + Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. + + For example: + "key": [{"sender": null, "value": "some value"}] -> "key": "some value" + + :param data: The JSON structure to transform. + :returns: The transformed structure. + """ + if isinstance(data, dict): + # If this dict has both 'sender' and 'value', return just the value + if "value" in data and "sender" in data: + return data["value"] + # Otherwise, recursively process each key-value pair + return {k: _transform_json_structure(v) for k, v in data.items()} + + if isinstance(data, list): + # First, transform each item in the list. + transformed = [_transform_json_structure(item) for item in data] + # If the original list has exactly one element and that element was a dict + # with 'sender' and 'value', then unwrap the list. + if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]: + return transformed[0] + return transformed + + # For other data types, just return the value as is. + return data + + +def _trigger_break_point(*, pipeline_snapshot: PipelineSnapshot, pipeline_outputs: Dict[str, Any]) -> None: + """ + Trigger a breakpoint by saving a snapshot and raising exception. 
+ + :param pipeline_snapshot: The current pipeline snapshot containing the state and break point. + :param pipeline_outputs: Current pipeline outputs. + :raises BreakpointException: When the breakpoint is triggered. + """ + _save_pipeline_snapshot(pipeline_snapshot=pipeline_snapshot) + + if isinstance(pipeline_snapshot.break_point, Breakpoint): + component_name = pipeline_snapshot.break_point.component_name + else: + component_name = pipeline_snapshot.break_point.agent_name + + component_visits = pipeline_snapshot.pipeline_state.component_visits + msg = f"Breaking at component {component_name} at visit count {component_visits[component_name]}" + raise BreakpointException( + message=msg, component=component_name, inputs=pipeline_snapshot.pipeline_state.inputs, results=pipeline_outputs + ) + + +def _create_agent_snapshot( + *, component_visits: Dict[str, int], agent_breakpoint: AgentBreakpoint, component_inputs: Dict[str, Any] +) -> AgentSnapshot: + """ + Create a snapshot of the agent's state. + + :param component_visits: The visit counts for the agent's components. + :param agent_breakpoint: AgentBreakpoint object containing breakpoints. + :return: An AgentSnapshot containing the agent's state and component visits. + """ + return AgentSnapshot( + component_inputs={ + "chat_generator": _serialize_value_with_schema(deepcopy(component_inputs["chat_generator"])), + "tool_invoker": _serialize_value_with_schema(deepcopy(component_inputs["tool_invoker"])), + }, + component_visits=component_visits, + break_point=agent_breakpoint, + timestamp=datetime.now(), + ) + + +def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: Union[List[Tool], Toolset]) -> None: + """ + Validates the AgentBreakpoint passed to the agent. + + Validates that the tool name in the ToolBreakpoint corresponds to a tool available in the agent. + + :param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components. + :param tools: List of Tool objects or a Toolset that the agent can use. + :raises ValueError: If the tool name in the ToolBreakpoint is not available in the agent's tools. + """ + + available_tool_names = {tool.name for tool in tools} + tool_breakpoint = agent_breakpoint.break_point + # Assert added for mypy to pass, but this is already checked before this function is called + assert isinstance(tool_breakpoint, ToolBreakpoint) + if tool_breakpoint.tool_name and tool_breakpoint.tool_name not in available_tool_names: + raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools") + + +def _check_chat_generator_breakpoint( + *, agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot] +) -> None: + """ + Check for a breakpoint before calling the ChatGenerator. + + :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints. + :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent.
+ :raises BreakpointException: If a breakpoint is triggered + """ + + break_point = agent_snapshot.break_point.break_point + + if parent_snapshot is None: + # Create an empty pipeline snapshot if no parent snapshot is provided + final_snapshot = PipelineSnapshot( + pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}), + timestamp=agent_snapshot.timestamp, + break_point=agent_snapshot.break_point, + agent_snapshot=agent_snapshot, + original_input_data={}, + ordered_component_names=[], + include_outputs_from=set(), + ) + else: + final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot) + _save_pipeline_snapshot(pipeline_snapshot=final_snapshot) + + msg = ( + f"Breaking at {break_point.component_name} visit count " + f"{agent_snapshot.component_visits[break_point.component_name]}" + ) + logger.info(msg) + raise BreakpointException( + message=msg, + component=break_point.component_name, + inputs=agent_snapshot.component_inputs, + results=agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"], + ) + + +def _check_tool_invoker_breakpoint( + *, llm_messages: List[ChatMessage], agent_snapshot: AgentSnapshot, parent_snapshot: Optional[PipelineSnapshot] +) -> None: + """ + Check for breakpoint before calling the ToolInvoker. + + :param llm_messages: List of ChatMessage objects containing potential tool calls. + :param agent_snapshot: AgentSnapshot object containing the agent's state and breakpoints. + :param parent_snapshot: Optional parent snapshot containing the state of the pipeline that houses the agent. + :raises BreakpointException: If a breakpoint is triggered + """ + if not isinstance(agent_snapshot.break_point.break_point, ToolBreakpoint): + return + + tool_breakpoint = agent_snapshot.break_point.break_point + + # Check if we should break for this specific tool or all tools + if tool_breakpoint.tool_name is None: + # Break for any tool call + should_break = any(msg.tool_call for msg in llm_messages) + else: + # Break only for the specific tool + should_break = any( + msg.tool_call and msg.tool_call.tool_name == tool_breakpoint.tool_name for msg in llm_messages + ) + + if not should_break: + return # No breakpoint triggered + + if parent_snapshot is None: + # Create an empty pipeline snapshot if no parent snapshot is provided + final_snapshot = PipelineSnapshot( + pipeline_state=PipelineState(inputs={}, component_visits={}, pipeline_outputs={}), + timestamp=agent_snapshot.timestamp, + break_point=agent_snapshot.break_point, + agent_snapshot=agent_snapshot, + original_input_data={}, + ordered_component_names=[], + include_outputs_from=set(), + ) + else: + final_snapshot = replace(parent_snapshot, agent_snapshot=agent_snapshot) + _save_pipeline_snapshot(pipeline_snapshot=final_snapshot) + + msg = ( + f"Breaking at {tool_breakpoint.component_name} visit count " + f"{agent_snapshot.component_visits[tool_breakpoint.component_name]}" + ) + if tool_breakpoint.tool_name: + msg += f" for tool {tool_breakpoint.tool_name}" + logger.info(msg) + + raise BreakpointException( + message=msg, + component=tool_breakpoint.component_name, + inputs=agent_snapshot.component_inputs, + results=agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"], + ) diff --git a/haystack/core/pipeline/pipeline.py b/haystack/core/pipeline/pipeline.py index 8020f6c104..221b496ffe 100644 --- a/haystack/core/pipeline/pipeline.py +++ b/haystack/core/pipeline/pipeline.py @@ -2,11 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import 
Any, Dict, Mapping, Optional, Set +from copy import deepcopy +from typing import Any, Dict, Mapping, Optional, Set, Union from haystack import logging, tracing from haystack.core.component import Component -from haystack.core.errors import PipelineRuntimeError +from haystack.core.errors import BreakpointException, PipelineInvalidPipelineSnapshotError, PipelineRuntimeError from haystack.core.pipeline.base import ( _COMPONENT_INPUT, _COMPONENT_OUTPUT, @@ -14,8 +15,16 @@ ComponentPriority, PipelineBase, ) +from haystack.core.pipeline.breakpoint import ( + _create_pipeline_snapshot, + _trigger_break_point, + _validate_break_point_against_pipeline, + _validate_pipeline_snapshot_against_pipeline, +) from haystack.core.pipeline.utils import _deepcopy_with_exceptions +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot from haystack.telemetry import pipeline_running +from haystack.utils import _deserialize_value_with_schema logger = logging.getLogger(__name__) @@ -59,6 +68,11 @@ def _run_component( try: component_output = instance.run(**inputs) + except BreakpointException as error: + # Re-raise BreakpointException to preserve the original exception context + # This is important when Agent components internally use Pipeline._run_component + # and trigger breakpoints that need to bubble up to the main pipeline + raise error except Exception as error: raise PipelineRuntimeError.from_exception(component_name, instance.__class__, error) from error @@ -72,8 +86,13 @@ def _run_component( return component_output - def run( # noqa: PLR0915, PLR0912 - self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None + def run( # noqa: PLR0915, PLR0912, C901, pylint: disable=too-many-branches + self, + data: Dict[str, Any], + include_outputs_from: Optional[Set[str]] = None, + *, + break_point: Optional[Union[Breakpoint, AgentBreakpoint]] = None, + pipeline_snapshot: Optional[PipelineSnapshot] = None, + ) -> Dict[str, Any]: """ Runs the Pipeline with given input data. @@ -149,6 +168,13 @@ def run( # noqa: PLR0915, PLR0912 included in the pipeline's output. For components that are invoked multiple times (in a loop), only the last-produced output is included. + + :param break_point: + A breakpoint that can be used to debug the pipeline execution. It stops the run at a specific + component and visit count. + + :param pipeline_snapshot: + A PipelineSnapshot containing a snapshot of a previously saved pipeline execution. + :returns: A dictionary where each entry corresponds to a component name and its output. If `include_outputs_from` is `None`, this dictionary @@ -163,36 +189,65 @@ def run( # noqa: PLR0915, PLR0912 Or if a Component fails or returns output in an unsupported type. :raises PipelineMaxComponentRuns: If a Component reaches the maximum number of times it can be run in this Pipeline. + :raises BreakpointException: + When a break_point is triggered. Contains the component name, state, and partial results. """ pipeline_running(self) + if break_point and pipeline_snapshot: + msg = ( + "break_point and pipeline_snapshot cannot be provided at the same time. " + "The pipeline run will be aborted." + ) + raise PipelineInvalidPipelineSnapshotError(message=msg) + + # make sure all breakpoints are valid, i.e.
reference components in the pipeline + if break_point: + _validate_break_point_against_pipeline(break_point, self.graph) + # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() self.warm_up() - # normalize `data` - data = self._prepare_component_input_data(data) - - # Raise ValueError if input is malformed in some way - self.validate_input(data) - if include_outputs_from is None: include_outputs_from = set() - # We create a list of components in the pipeline sorted by name, so that the algorithm runs deterministically - # and independent of insertion order into the pipeline. - ordered_component_names = sorted(self.graph.nodes.keys()) + pipeline_outputs: Dict[str, Any] = {} + + if not pipeline_snapshot: + # normalize `data` + data = self._prepare_component_input_data(data) + + # Raise ValueError if input is malformed in some way + self.validate_input(data) + + # We create a list of components in the pipeline sorted by name, so that the algorithm runs + # deterministically and independent of insertion order into the pipeline. + ordered_component_names = sorted(self.graph.nodes.keys()) + + # We track component visits to decide if a component can run. + component_visits = dict.fromkeys(ordered_component_names, 0) + + else: + # Validate the pipeline snapshot against the current pipeline graph + _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot, self.graph) + + # Handle resuming the pipeline from a snapshot + component_visits = pipeline_snapshot.pipeline_state.component_visits + ordered_component_names = pipeline_snapshot.ordered_component_names + data = _deserialize_value_with_schema(pipeline_snapshot.pipeline_state.inputs) - # We track component visits to decide if a component can run. - component_visits = dict.fromkeys(ordered_component_names, 0) + # Restore include_outputs_from from the snapshot when resuming + include_outputs_from = pipeline_snapshot.include_outputs_from + # Also restore the pipeline_outputs collected up to the breakpoint from the snapshot + pipeline_outputs = pipeline_snapshot.pipeline_state.pipeline_outputs + + cached_topological_sort = None # We need to access a component's receivers multiple times during a pipeline run. # We store them here for easy access.
cached_receivers = {name: self._find_receivers_from(name) for name in ordered_component_names} - cached_topological_sort = None - - pipeline_outputs: Dict[str, Any] = {} with tracing.tracer.trace( "haystack.pipeline.run", tags={ @@ -242,30 +297,98 @@ def run( # noqa: PLR0915, PLR0912 priority_queue=priority_queue, topological_sort=cached_topological_sort, ) + cached_topological_sort = topological_sort component = self._get_component_with_graph_metadata_and_visits( component_name, component_visits[component_name] ) + if pipeline_snapshot: + if isinstance(pipeline_snapshot.break_point, AgentBreakpoint): + name_to_check = pipeline_snapshot.break_point.agent_name + else: + name_to_check = pipeline_snapshot.break_point.component_name + is_resume = name_to_check == component_name + else: + is_resume = False component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs + component_name=component_name, component=component, inputs=inputs, is_resume=is_resume ) + # We need to add missing defaults using default values from input sockets because the run signature # might not provide these defaults for components with inputs defined dynamically upon component # initialization component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) + # Scenario 1: Pipeline snapshot is provided to resume the pipeline at a specific component + # Deserialize the component_inputs if they are passed in the pipeline_snapshot. + # This check prevents other component_inputs generated at runtime from being deserialized + if pipeline_snapshot and component_name in pipeline_snapshot.pipeline_state.inputs.keys(): + for key, value in component_inputs.items(): + component_inputs[key] = _deserialize_value_with_schema(value) + + # If we are resuming from an AgentBreakpoint, we inject the agent_snapshot into the Agent's inputs + if ( + pipeline_snapshot + and isinstance(pipeline_snapshot.break_point, AgentBreakpoint) + and component_name == pipeline_snapshot.break_point.agent_name + ): + component_inputs["snapshot"] = pipeline_snapshot.agent_snapshot + component_inputs["break_point"] = None + + # Scenario 2: A breakpoint is provided to stop the pipeline at a specific component + if break_point: + should_trigger_breakpoint = False + should_create_snapshot = False + + # Scenario 2.1: an AgentBreakpoint is provided to stop the pipeline at a specific component + if isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name: + should_create_snapshot = True + component_inputs["break_point"] = break_point + + # Scenario 2.2: a regular breakpoint is provided to stop the pipeline at a specific component and + # visit count + elif ( + isinstance(break_point, Breakpoint) + and break_point.component_name == component_name + and break_point.visit_count == component_visits[component_name] + ): + should_trigger_breakpoint = True + should_create_snapshot = True + + if should_create_snapshot: + pipeline_snapshot_inputs_serialised = deepcopy(inputs) + pipeline_snapshot_inputs_serialised[component_name] = deepcopy(component_inputs) + new_pipeline_snapshot = _create_pipeline_snapshot( + inputs=pipeline_snapshot_inputs_serialised, + break_point=break_point, + component_visits=component_visits, + original_input_data=data, + ordered_component_names=ordered_component_names, + include_outputs_from=include_outputs_from, + pipeline_outputs=pipeline_outputs, + ) + + # add the parent_snapshot to agent inputs if needed + if
isinstance(break_point, AgentBreakpoint) and component_name == break_point.agent_name: + component_inputs["parent_snapshot"] = new_pipeline_snapshot + + # trigger the breakpoint if needed + if should_trigger_breakpoint: + _trigger_break_point( + pipeline_snapshot=new_pipeline_snapshot, pipeline_outputs=pipeline_outputs + ) + component_outputs = self._run_component( component_name=component_name, component=component, - inputs=component_inputs, + inputs=component_inputs, # the inputs to the current component component_visits=component_visits, parent_span=span, ) # Updates global input state with component outputs and returns outputs that should go to # pipeline outputs. - component_pipeline_outputs = self._write_component_outputs( component_name=component_name, component_outputs=component_outputs, @@ -275,8 +398,16 @@ def run( # noqa: PLR0915, PLR0912 ) if component_pipeline_outputs: - pipeline_outputs[component_name] = _deepcopy_with_exceptions(component_pipeline_outputs) + pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) if self._is_queue_stale(priority_queue): priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) + if isinstance(break_point, Breakpoint): + logger.warning( + "The given breakpoint {break_point} was never triggered. This is because:\n" + "1. The provided component is not a part of the pipeline execution path.\n" + "2. The component did not reach the visit count specified in the break_point.", + break_point=break_point, + ) + return pipeline_outputs diff --git a/haystack/dataclasses/breakpoints.py b/haystack/dataclasses/breakpoints.py new file mode 100644 index 0000000000..4e0f75ad4c --- /dev/null +++ b/haystack/dataclasses/breakpoints.py @@ -0,0 +1,261 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import asdict, dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional, Set, Union + + +@dataclass(frozen=True) +class Breakpoint: + """ + A dataclass to hold a breakpoint for a component. + + :param component_name: The name of the component where the breakpoint is set. + :param visit_count: The number of times the component must be visited before the breakpoint is triggered. + :param snapshot_file_path: Optional path to store a snapshot of the pipeline when the breakpoint is hit. + This is useful for debugging purposes, allowing you to inspect the state of the pipeline at the time of the + breakpoint and to resume execution from that point. + """ + + component_name: str + visit_count: int = 0 + snapshot_file_path: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the Breakpoint to a dictionary representation. + + :return: A dictionary containing the component name, visit count, and snapshot file path. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "Breakpoint": + """ + Populate the Breakpoint from a dictionary representation. + + :param data: A dictionary containing the component name, visit count, and snapshot file path. + :return: An instance of Breakpoint. + """ + return cls(**data) + + +@dataclass(frozen=True) +class ToolBreakpoint(Breakpoint): + """ + A dataclass representing a breakpoint specific to tools used within an Agent component. + + Inherits from Breakpoint and adds the ability to target individual tools. If `tool_name` is None, + the breakpoint applies to all tools within the Agent component.
+ + :param tool_name: The name of the tool to target within the Agent component. If None, applies to all tools. + """ + + tool_name: Optional[str] = None + + def __str__(self) -> str: + tool_str = f", tool_name={self.tool_name}" if self.tool_name else ", tool_name=ALL_TOOLS" + return f"ToolBreakpoint(component_name={self.component_name}, visit_count={self.visit_count}{tool_str})" + + +@dataclass(frozen=True) +class AgentBreakpoint: + """ + A dataclass representing a breakpoint tied to an Agent’s execution. + + This allows for debugging either a specific component (e.g., the chat generator) or a tool used by the agent. + It enforces constraints on which component names are valid for each breakpoint type. + + :param agent_name: The name of the agent component in a pipeline where the breakpoint is set. + :param break_point: An instance of Breakpoint or ToolBreakpoint indicating where to break execution. + + :raises ValueError: If the component_name is invalid for the given breakpoint type: + - Breakpoint must have component_name='chat_generator'. + - ToolBreakpoint must have component_name='tool_invoker'. + """ + + agent_name: str + break_point: Union[Breakpoint, ToolBreakpoint] + + def __post_init__(self): + if ( + isinstance(self.break_point, Breakpoint) and not isinstance(self.break_point, ToolBreakpoint) + ) and self.break_point.component_name != "chat_generator": + raise ValueError("If the break_point is a Breakpoint, it must have the component_name 'chat_generator'.") + + if isinstance(self.break_point, ToolBreakpoint) and self.break_point.component_name != "tool_invoker": + raise ValueError("If the break_point is a ToolBreakpoint, it must have the component_name 'tool_invoker'.") + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the AgentBreakpoint to a dictionary representation. + + :return: A dictionary containing the agent name and the breakpoint details. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "AgentBreakpoint": + """ + Populate the AgentBreakpoint from a dictionary representation. + + :param data: A dictionary containing the agent name and the breakpoint details. + :return: An instance of AgentBreakpoint. + """ + break_point_data = data["break_point"] + break_point: Union[Breakpoint, ToolBreakpoint] + if "tool_name" in break_point_data: + break_point = ToolBreakpoint(**break_point_data) + else: + break_point = Breakpoint(**break_point_data) + return cls(agent_name=data["agent_name"], break_point=break_point) + + +@dataclass +class AgentSnapshot: + component_inputs: Dict[str, Any] + component_visits: Dict[str, int] + break_point: AgentBreakpoint + timestamp: Optional[datetime] = None + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the AgentSnapshot to a dictionary representation. + + :return: A dictionary containing the agent state, timestamp, and breakpoint. + """ + return { + "component_inputs": self.component_inputs, + "component_visits": self.component_visits, + "break_point": self.break_point.to_dict(), + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + } + + @classmethod + def from_dict(cls, data: dict) -> "AgentSnapshot": + """ + Populate the AgentSnapshot from a dictionary representation. + + :param data: A dictionary containing the agent state, timestamp, and breakpoint. + :return: An instance of AgentSnapshot. 
+ """ + return cls( + component_inputs=data["component_inputs"], + component_visits=data["component_visits"], + break_point=AgentBreakpoint.from_dict(data["break_point"]), + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + ) + + +@dataclass +class PipelineState: + """ + A dataclass to hold the state of the pipeline at a specific point in time. + + :param component_visits: A dictionary mapping component names to their visit counts. + :param inputs: The inputs processed by the pipeline at the time of the snapshot. + :param pipeline_outputs: Dictionary containing the final outputs of the pipeline up to the breakpoint. + """ + + inputs: Dict[str, Any] + component_visits: Dict[str, int] + pipeline_outputs: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the PipelineState to a dictionary representation. + + :return: A dictionary containing the inputs, component visits, + and pipeline outputs. + """ + return asdict(self) + + @classmethod + def from_dict(cls, data: dict) -> "PipelineState": + """ + Populate the PipelineState from a dictionary representation. + + :param data: A dictionary containing the inputs, component visits, + and pipeline outputs. + :return: An instance of PipelineState. + """ + return cls(**data) + + +@dataclass +class PipelineSnapshot: + """ + A dataclass to hold a snapshot of the pipeline at a specific point in time. + + :param pipeline_state: The state of the pipeline at the time of the snapshot. + :param break_point: The breakpoint that triggered the snapshot. + :param agent_snapshot: Optional agent snapshot if the breakpoint is an agent breakpoint. + :param timestamp: A timestamp indicating when the snapshot was taken. + :param original_input_data: The original input data provided to the pipeline. + :param ordered_component_names: A list of component names in the order they were visited. + :param include_outputs_from: Set of component names whose outputs should be included in the pipeline results. + """ + + original_input_data: Dict[str, Any] + ordered_component_names: List[str] + pipeline_state: PipelineState + break_point: Union[AgentBreakpoint, Breakpoint] + agent_snapshot: Optional[AgentSnapshot] = None + timestamp: Optional[datetime] = None + include_outputs_from: Set[str] = field(default_factory=set) + + def __post_init__(self): + # Validate consistency between component_visits and ordered_component_names + components_in_state = set(self.pipeline_state.component_visits.keys()) + components_in_order = set(self.ordered_component_names) + + if components_in_state != components_in_order: + raise ValueError( + f"Inconsistent state: components in PipelineState.component_visits {components_in_state} " + f"do not match components in PipelineSnapshot.ordered_component_names {components_in_order}" + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the PipelineSnapshot to a dictionary representation. + + :return: A dictionary containing the pipeline state, timestamp, breakpoint, agent snapshot, original input data, + ordered component names, include_outputs_from, and pipeline outputs. 
+ """ + data = { + "pipeline_state": self.pipeline_state.to_dict(), + "break_point": self.break_point.to_dict(), + "agent_snapshot": self.agent_snapshot.to_dict() if self.agent_snapshot else None, + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "original_input_data": self.original_input_data, + "ordered_component_names": self.ordered_component_names, + "include_outputs_from": list(self.include_outputs_from), + } + return data + + @classmethod + def from_dict(cls, data: dict) -> "PipelineSnapshot": + """ + Populate the PipelineSnapshot from a dictionary representation. + + :param data: A dictionary containing the pipeline state, timestamp, breakpoint, agent snapshot, original input + data, ordered component names, include_outputs_from, and pipeline outputs. + """ + # Convert include_outputs_from list back to set for serialization + include_outputs_from = set(data.get("include_outputs_from", [])) + + return cls( + pipeline_state=PipelineState.from_dict(data=data["pipeline_state"]), + break_point=( + AgentBreakpoint.from_dict(data=data["break_point"]) + if "agent_name" in data["break_point"] + else Breakpoint.from_dict(data=data["break_point"]) + ), + agent_snapshot=AgentSnapshot.from_dict(data["agent_snapshot"]) if data.get("agent_snapshot") else None, + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + original_input_data=data.get("original_input_data", {}), + ordered_component_names=data.get("ordered_component_names", []), + include_outputs_from=include_outputs_from, + ) diff --git a/haystack/utils/base_serialization.py b/haystack/utils/base_serialization.py index 3303fc61f3..2f98b30471 100644 --- a/haystack/utils/base_serialization.py +++ b/haystack/utils/base_serialization.py @@ -208,7 +208,7 @@ def _deserialize_value_with_schema(serialized: Dict[str, Any]) -> Any: # pylint schema_type = schema.get("type") if not schema_type: - # for backward comaptability till Haystack 2.16 we use legacy implementation + # for backward compatibility till Haystack 2.16 we use legacy implementation raise DeserializationError( "Missing 'type' key in 'serialization_schema'. This likely indicates that you're using a serialized " "State object created with a version of Haystack older than 2.15.0. " diff --git a/pyproject.toml b/pyproject.toml index 4705b76a83..e4cf9aad7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ dependencies = [ "pip", # mypy needs pip to install missing stub packages "pylint", "ipython", + "colorama==0.4.6", # Pipeline checkpoints test - ToDo: rewrite the test without this lib ] [tool.hatch.envs.test.scripts] diff --git a/releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml b/releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml new file mode 100644 index 0000000000..74c10fd100 --- /dev/null +++ b/releasenotes/notes/feat-pipeline-and-agents-breakpoints-aa819128c8c5f456.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + The Pipeline and the Agent now support breakpoints, a feature useful for debugging. A breakpoint is associated with a component and it stops the execution of a Pipeline/Agent generating a JSON file with the execution status, which can be inspected + edited and later used to resume the execution. 
diff --git a/test/components/agents/test_agent_breakpoints_inside_pipeline.py b/test/components/agents/test_agent_breakpoints_inside_pipeline.py
new file mode 100644
index 0000000000..91a4d6ebac
--- /dev/null
+++ b/test/components/agents/test_agent_breakpoints_inside_pipeline.py
@@ -0,0 +1,360 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import re
+import tempfile
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import pytest
+
+from haystack import component
+from haystack.components.agents import Agent
+from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
+from haystack.components.generators.chat import OpenAIChatGenerator
+from haystack.core.errors import BreakpointException
+from haystack.core.pipeline import Pipeline
+from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
+from haystack.dataclasses import ByteStream, ChatMessage, Document, ToolCall
+from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint
+from haystack.document_stores.in_memory import InMemoryDocumentStore
+from haystack.tools import create_tool_from_function
+
+document_store = InMemoryDocumentStore()
+
+
+@component
+class MockLinkContentFetcher:
+    @component.output_types(streams=List[ByteStream])
+    def run(self, urls: List[str]) -> Dict[str, List[ByteStream]]:
+        mock_html_content = """
+        <html>
+            <head>
+                <title>Deepset - About Our Team</title>
+            </head>
+            <body>
+                <h1>About Deepset</h1>
+                <p>Deepset is a company focused on natural language processing and AI.</p>
+                <h2>Our Leadership Team</h2>
+                <div class="team-member">
+                    <h3>Malte Pietsch</h3>
+                    <p>Malte Pietsch is the CEO and co-founder of Deepset. He has extensive experience in machine learning
+                    and natural language processing.</p>
+                    <p>Job Title: Chief Executive Officer</p>
+                </div>
+                <div class="team-member">
+                    <h3>Milos Rusic</h3>
+                    <p>Milos Rusic is the CTO and co-founder of Deepset. He specializes in building scalable AI systems
+                    and has worked on various NLP projects.</p>
+                    <p>Job Title: Chief Technology Officer</p>
+                </div>
+                <h2>Our Mission</h2>
+                <p>Deepset aims to make natural language processing accessible to developers and businesses worldwide
+                through open-source tools and enterprise solutions.</p>
+            </body>
+        </html>
+ + + """ + + bytestream = ByteStream( + data=mock_html_content.encode("utf-8"), + mime_type="text/html", + meta={"url": urls[0] if urls else "https://en.wikipedia.org/wiki/Deepset"}, + ) + + return {"streams": [bytestream]} + + +@component +class MockHTMLToDocument: + @component.output_types(documents=List[Document]) + def run(self, sources: List[ByteStream]) -> Dict[str, List[Document]]: + """Mock HTML to Document converter that extracts text content from HTML ByteStreams.""" + + documents = [] + for source in sources: + # Extract the HTML content from the ByteStream + html_content = source.data.decode("utf-8") + + # Simple text extraction - remove HTML tags and extract meaningful content + # This is a simplified version that extracts the main content + # Remove HTML tags + text_content = re.sub(r"<[^>]+>", " ", html_content) + # Remove extra whitespace + text_content = re.sub(r"\s+", " ", text_content).strip() + + # Create a Document with the extracted text + document = Document( + content=text_content, + meta={"url": source.meta.get("url", "unknown"), "mime_type": source.mime_type, "source_type": "html"}, + ) + documents.append(document) + + return {"documents": documents} + + +def add_database_tool_function(name: str, surname: str, job_title: Optional[str], other: Optional[str]): + document_store.write_documents( + [Document(content=name + " " + surname + " " + (job_title or ""), meta={"other": other})] + ) + + +# We use this since the @tool decorator has issues with deserialization +add_database_tool = create_tool_from_function(add_database_tool_function, name="add_database_tool") + + +@pytest.fixture +def pipeline_with_agent(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test_key") + generator = OpenAIChatGenerator() + call_count = 0 + + def mock_run(messages, tools=None, **kwargs): + nonlocal call_count + call_count += 1 + + if call_count == 1: + return { + "replies": [ + ChatMessage.from_assistant( + "I'll extract the information about the people mentioned in the context.", + tool_calls=[ + ToolCall( + tool_name="add_database_tool", + arguments={ + "name": "Malte", + "surname": "Pietsch", + "job_title": "Chief Executive Officer", + "other": "CEO and co-founder of Deepset with extensive experience in machine " + "learning and natural language processing", + }, + ), + ToolCall( + tool_name="add_database_tool", + arguments={ + "name": "Milos", + "surname": "Rusic", + "job_title": "Chief Technology Officer", + "other": "CTO and co-founder of Deepset specializing in building scalable " + "AI systems and NLP projects", + }, + ), + ], + ) + ] + } + else: + return { + "replies": [ + ChatMessage.from_assistant( + "I have successfully extracted and stored information about the following people:\n\n" + "1. **Malte Pietsch** - Chief Executive Officer\n" + " - CEO and co-founder of Deepset\n" + " - Extensive experience in machine learning and natural language processing\n\n" + "2. **Milos Rusic** - Chief Technology Officer\n" + " - CTO and co-founder of Deepset\n" + " - Specializes in building scalable AI systems and NLP projects\n\n" + "Both individuals have been added to the knowledge base with their respective information." + ) + ] + } + + generator.run = mock_run + + database_assistant = Agent( + chat_generator=generator, + tools=[add_database_tool], + system_prompt=""" + You are a database assistant. 
+        Your task is to extract the names of people mentioned in the given context and add them to a knowledge base,
+        along with additional relevant information about them that can be extracted from the context.
+        Do not use your own knowledge; stay grounded in the given context.
+        Do not ask the user for confirmation. Instead, automatically update the knowledge base and return a brief
+        summary of the people added, including the information stored for each.
+        """,
+        exit_conditions=["text"],
+        max_agent_steps=100,
+        raise_on_tool_invocation_failure=False,
+    )
+
+    extraction_agent = Pipeline()
+    extraction_agent.add_component("fetcher", MockLinkContentFetcher())
+    extraction_agent.add_component("converter", MockHTMLToDocument())
+    extraction_agent.add_component(
+        "builder",
+        ChatPromptBuilder(
+            template=[
+                ChatMessage.from_user("""
+                {% for doc in docs %}
+                {{ doc.content|default|truncate(25000) }}
+                {% endfor %}
+                """)
+            ],
+            required_variables=["docs"],
+        ),
+    )
+    extraction_agent.add_component("database_agent", database_assistant)
+
+    extraction_agent.connect("fetcher.streams", "converter.sources")
+    extraction_agent.connect("converter.documents", "builder.docs")
+    extraction_agent.connect("builder", "database_agent")
+
+    return extraction_agent
+
+
+def run_pipeline_without_any_breakpoints(pipeline_with_agent):
+    agent_output = pipeline_with_agent.run(data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}})
+
+    # pipeline completed
+    assert "database_agent" in agent_output
+    assert "messages" in agent_output["database_agent"]
+    assert len(agent_output["database_agent"]["messages"]) > 0
+
+    # final message contains the expected summary
+    final_message = agent_output["database_agent"]["messages"][-1].text
+    assert "Malte Pietsch" in final_message
+    assert "Milos Rusic" in final_message
+    assert "Chief Executive Officer" in final_message
+    assert "Chief Technology Officer" in final_message
+
+
+def test_chat_generator_breakpoint_in_pipeline_agent(pipeline_with_agent):
+    with tempfile.TemporaryDirectory() as debug_path:
+        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
+        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")
+        try:
+            pipeline_with_agent.run(
+                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
+            )
+            assert False, "Expected BreakpointException was not raised"
+
+        except BreakpointException as e:  # this is the exception from the Agent
+            assert e.component == "chat_generator"
+            assert e.inputs is not None
+            assert "messages" in e.inputs["chat_generator"]["serialized_data"]
+            assert e.results is not None
+
+            # verify that a snapshot file was created
+            chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
+            assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"
+
+
+def test_tool_breakpoint_in_pipeline_agent(pipeline_with_agent):
+    with tempfile.TemporaryDirectory() as debug_path:
+        agent_tool_breakpoint = ToolBreakpoint(
+            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
+        )
+        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")
+        try:
+            pipeline_with_agent.run(
+                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
+            )
+            assert False, "Expected BreakpointException was not raised"
+        except BreakpointException as e:  # this is the exception from the Agent
+            assert e.component == "tool_invoker"
+            assert e.inputs is not None
+            assert "messages" in e.inputs["tool_invoker"]["serialized_data"]
+            assert e.results is not None
+
+            # verify that a snapshot file was created
+            tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json"))
+            assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}"
+
+
+def test_agent_breakpoint_chat_generator_and_resume_pipeline(pipeline_with_agent):
+    with tempfile.TemporaryDirectory() as debug_path:
+        agent_generator_breakpoint = Breakpoint("chat_generator", 0, snapshot_file_path=debug_path)
+        agent_breakpoint = AgentBreakpoint(break_point=agent_generator_breakpoint, agent_name="database_agent")
+        try:
+            pipeline_with_agent.run(
+                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
+            )
+            assert False, "Expected BreakpointException was not raised"
+
+        except BreakpointException as e:
+            assert e.component == "chat_generator"
+            assert e.inputs is not None
+            assert "messages" in e.inputs["chat_generator"]["serialized_data"]
+            assert e.results is not None
+
+            # verify that the snapshot file was created
+            chat_generator_snapshot_files = list(Path(debug_path).glob("database_agent_chat_generator_*.json"))
+            assert len(chat_generator_snapshot_files) > 0, f"No chat_generator snapshot file found in {debug_path}"
+
+            # resume the pipeline from the saved snapshot
+            latest_snapshot_file = max(chat_generator_snapshot_files, key=os.path.getctime)
+            result = pipeline_with_agent.run(data={}, pipeline_snapshot=load_pipeline_snapshot(latest_snapshot_file))
+
+            # pipeline completed successfully after resuming
+            assert "database_agent" in result
+            assert "messages" in result["database_agent"]
+            assert len(result["database_agent"]["messages"]) > 0
+
+            # final message contains the expected summary
+            final_message = result["database_agent"]["messages"][-1].text
+            assert "Malte Pietsch" in final_message
+            assert "Milos Rusic" in final_message
+            assert "Chief Executive Officer" in final_message
+            assert "Chief Technology Officer" in final_message
+
+            # tool should have been called during the resumed execution
+            documents = document_store.filter_documents()
+            assert len(documents) >= 2, "Expected at least 2 documents to be added to the database"
+
+            # both people were added
+            person_names = [doc.content for doc in documents]
+            assert any("Malte Pietsch" in name for name in person_names)
+            assert any("Milos Rusic" in name for name in person_names)
+
+
+def test_agent_breakpoint_tool_and_resume_pipeline(pipeline_with_agent):
+    with tempfile.TemporaryDirectory() as debug_path:
+        agent_tool_breakpoint = ToolBreakpoint(
+            "tool_invoker", 0, tool_name="add_database_tool", snapshot_file_path=debug_path
+        )
+        agent_breakpoint = AgentBreakpoint(break_point=agent_tool_breakpoint, agent_name="database_agent")
+        try:
+            pipeline_with_agent.run(
+                data={"fetcher": {"urls": ["https://en.wikipedia.org/wiki/Deepset"]}}, break_point=agent_breakpoint
+            )
+            assert False, "Expected BreakpointException was not raised"
+
+        except BreakpointException as e:
+            assert e.component == "tool_invoker"
+            assert e.inputs is not None
+            assert "serialization_schema" in e.inputs["tool_invoker"]
+            assert "serialized_data" in e.inputs["tool_invoker"]
+            assert "messages" in e.inputs["tool_invoker"]["serialized_data"]
+            assert e.results is not None
+
+            # verify that the snapshot file was created
tool_invoker_snapshot_files = list(Path(debug_path).glob("database_agent_tool_invoker_*.json")) + assert len(tool_invoker_snapshot_files) > 0, f"No tool_invoker snapshot file found in {debug_path}" + + # resume the pipeline from the saved snapshot + latest_snapshot_file = max(tool_invoker_snapshot_files, key=os.path.getctime) + pipeline_snapshot = load_pipeline_snapshot(latest_snapshot_file) + result = pipeline_with_agent.run(data={}, pipeline_snapshot=pipeline_snapshot) + + # pipeline completed successfully after resuming + assert "database_agent" in result + assert "messages" in result["database_agent"] + assert len(result["database_agent"]["messages"]) > 0 + + # final message contains the expected summary + final_message = result["database_agent"]["messages"][-1].text + assert "Malte Pietsch" in final_message + assert "Milos Rusic" in final_message + assert "Chief Executive Officer" in final_message + assert "Chief Technology Officer" in final_message + + # tool should have been called during the resumed execution + documents = document_store.filter_documents() + assert len(documents) >= 2, "Expected at least 2 documents to be added to the database" + + # both people were added + person_names = [doc.content for doc in documents] + assert any("Malte Pietsch" in name for name in person_names) + assert any("Milos Rusic" in name for name in person_names) diff --git a/test/components/agents/test_agent_breakpoints_isolation_async.py b/test/components/agents/test_agent_breakpoints_isolation_async.py new file mode 100644 index 0000000000..a41da346f3 --- /dev/null +++ b/test/components/agents/test_agent_breakpoints_isolation_async.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path +from unittest.mock import AsyncMock + +import pytest + +from haystack.components.agents import Agent +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint +from haystack.tools import Tool +from test.components.agents.test_agent import MockChatGeneratorWithRunAsync, weather_function + +AGENT_NAME = "isolated_agent" + + +@pytest.fixture +def weather_tool(): + return Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + function=weather_function, + ) + + +@pytest.fixture +def mock_chat_generator(): + generator = MockChatGeneratorWithRunAsync() + mock_run_async = AsyncMock() + mock_run_async.return_value = { + "replies": [ + ChatMessage.from_assistant( + "I'll help you check the weather.", + tool_calls=[{"tool_name": "weather_tool", "tool_args": {"location": "Berlin"}}], + ) + ] + } + + async def mock_run_async_with_tools(messages, tools=None, **kwargs): + return mock_run_async.return_value + + generator.run_async = mock_run_async_with_tools + return generator + + +@pytest.fixture +def agent(mock_chat_generator, weather_tool): + return Agent( + chat_generator=mock_chat_generator, + tools=[weather_tool], + system_prompt="You are a helpful assistant that can use tools to help users.", + max_agent_steps=10, # Increase max steps to allow breakpoints to trigger + ) + + +@pytest.fixture +def debug_path(tmp_path): + return str(tmp_path / "debug_snapshots") + + 
+@pytest.fixture +def mock_agent_with_tool_calls(monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = MockChatGeneratorWithRunAsync() + mock_messages = [ + ChatMessage.from_assistant("First response"), + ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]), + ] + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=10) # Increase max steps + agent.warm_up() + agent.chat_generator.run_async = AsyncMock(return_value={"replies": mock_messages}) + return agent + + +@pytest.mark.asyncio +async def test_run_async_with_chat_generator_breakpoint(agent): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test") + with pytest.raises(BreakpointException) as exc_info: + await agent.run_async(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) + assert exc_info.value.component == "chat_generator" + assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"] + + +@pytest.mark.asyncio +async def test_run_async_with_tool_invoker_breakpoint(mock_agent_with_tool_calls): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") + with pytest.raises(BreakpointException) as exc_info: + await mock_agent_with_tool_calls.run_async( + messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME + ) + + assert exc_info.value.component == "tool_invoker" + assert "messages" in exc_info.value.inputs["tool_invoker"]["serialized_data"] + + +@pytest.mark.asyncio +async def test_resume_from_chat_generator_async(agent, debug_path): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, snapshot_file_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=AGENT_NAME) + + try: + await agent.run_async(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) + except BreakpointException: + pass + + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_chat_generator_*.json")) + + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) + + result = await agent.run_async( + messages=[ChatMessage.from_user("Continue from where we left off.")], + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +@pytest.mark.asyncio +async def test_resume_from_tool_invoker_async(mock_agent_with_tool_calls, debug_path): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = ToolBreakpoint( + component_name="tool_invoker", visit_count=0, tool_name="weather_tool", snapshot_file_path=debug_path + ) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=AGENT_NAME) + + try: + await mock_agent_with_tool_calls.run_async( + messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME + ) + except BreakpointException: + pass + + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_tool_invoker_*.json")) + + assert len(snapshot_files) > 0 + 
latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) + + result = await mock_agent_with_tool_calls.run_async( + messages=[ChatMessage.from_user("Continue from where we left off.")], + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +@pytest.mark.asyncio +async def test_invalid_combination_breakpoint_and_pipeline_snapshot_async(mock_agent_with_tool_calls): + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test") + with pytest.raises(ValueError, match="break_point and snapshot cannot be provided at the same time"): + await mock_agent_with_tool_calls.run_async( + messages=messages, break_point=agent_breakpoint, snapshot={"some": "snapshot"} + ) + + +@pytest.mark.asyncio +async def test_breakpoint_with_invalid_component_async(mock_agent_with_tool_calls): + invalid_bp = Breakpoint(component_name="invalid_breakpoint", visit_count=0) + with pytest.raises(ValueError): + AgentBreakpoint(break_point=invalid_bp, agent_name="test") + + +@pytest.mark.asyncio +async def test_breakpoint_with_invalid_tool_name_async(mock_agent_with_tool_calls): + tool_breakpoint = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="invalid_tool") + with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"): + agent_breakpoint = AgentBreakpoint(break_point=tool_breakpoint, agent_name="test") + await mock_agent_with_tool_calls.run_async( + messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoint + ) diff --git a/test/components/agents/test_agent_breakpoints_isolation_sync.py b/test/components/agents/test_agent_breakpoints_isolation_sync.py new file mode 100644 index 0000000000..4afb4b92a9 --- /dev/null +++ b/test/components/agents/test_agent_breakpoints_isolation_sync.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os +from pathlib import Path + +import pytest + +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, ToolBreakpoint +from test.components.agents.test_agent_breakpoints_utils import ( + agent_sync, + mock_agent_with_tool_calls_sync, + weather_tool, +) + +AGENT_NAME = "isolated_agent" + + +def test_run_with_chat_generator_breakpoint(agent_sync): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name="test_agent") + with pytest.raises(BreakpointException) as exc_info: + agent_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") + assert exc_info.value.component == "chat_generator" + assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"] + + +def test_run_with_tool_invoker_breakpoint(mock_agent_with_tool_calls_sync): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, 
tool_name="weather_tool") + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") + with pytest.raises(BreakpointException) as exc_info: + mock_agent_with_tool_calls_sync.run(messages=messages, break_point=agent_breakpoint, agent_name="test") + + assert exc_info.value.component == "tool_invoker" + assert {"chat_generator", "tool_invoker"} == set(exc_info.value.inputs.keys()) + assert "serialization_schema" in exc_info.value.inputs["chat_generator"] + assert "serialized_data" in exc_info.value.inputs["chat_generator"] + assert "messages" in exc_info.value.inputs["chat_generator"]["serialized_data"] + + +def test_resume_from_chat_generator(agent_sync, tmp_path): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + debug_path = str(tmp_path / "debug_snapshots") + chat_generator_bp = Breakpoint(component_name="chat_generator", visit_count=0, snapshot_file_path=debug_path) + agent_breakpoint = AgentBreakpoint(break_point=chat_generator_bp, agent_name=AGENT_NAME) + + try: + agent_sync.run(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) + except BreakpointException: + pass + + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_chat_generator_*.json")) + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) + + result = agent_sync.run( + messages=[ChatMessage.from_user("Continue from where we left off.")], + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +def test_resume_from_tool_invoker(mock_agent_with_tool_calls_sync, tmp_path): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + debug_path = str(tmp_path / "debug_snapshots") + tool_bp = ToolBreakpoint( + component_name="tool_invoker", visit_count=0, tool_name=None, snapshot_file_path=debug_path + ) + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name=AGENT_NAME) + + try: + mock_agent_with_tool_calls_sync.run(messages=messages, break_point=agent_breakpoint, agent_name=AGENT_NAME) + except BreakpointException: + pass + + snapshot_files = list(Path(debug_path).glob(AGENT_NAME + "_tool_invoker_*.json")) + assert len(snapshot_files) > 0 + latest_snapshot_file = str(max(snapshot_files, key=os.path.getctime)) + + result = mock_agent_with_tool_calls_sync.run( + messages=[ChatMessage.from_user("Continue from where we left off.")], + snapshot=load_pipeline_snapshot(latest_snapshot_file).agent_snapshot, + ) + + assert "messages" in result + assert "last_message" in result + assert len(result["messages"]) > 0 + + +def test_invalid_combination_breakpoint_and_pipeline_snapshot(mock_agent_with_tool_calls_sync): # noqa: F811 + messages = [ChatMessage.from_user("What's the weather in Berlin?")] + tool_bp = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="weather_tool") + agent_breakpoint = AgentBreakpoint(break_point=tool_bp, agent_name="test_agent") + with pytest.raises(ValueError, match="break_point and snapshot cannot be provided at the same time"): + mock_agent_with_tool_calls_sync.run( + messages=messages, break_point=agent_breakpoint, snapshot={"some": "snapshot"} + ) + + +def test_breakpoint_with_invalid_component(mock_agent_with_tool_calls_sync): # noqa: F811 + invalid_bp = Breakpoint(component_name="invalid_breakpoint", visit_count=0) + with pytest.raises(ValueError): + AgentBreakpoint(break_point=invalid_bp, 
agent_name="test_agent") + + +def test_breakpoint_with_invalid_tool_name(mock_agent_with_tool_calls_sync): # noqa: F811 + tool_breakpoint = ToolBreakpoint(component_name="tool_invoker", visit_count=0, tool_name="invalid_tool") + with pytest.raises(ValueError, match="Tool 'invalid_tool' is not available in the agent's tools"): + agent_breakpoints = AgentBreakpoint(break_point=tool_breakpoint, agent_name="test_agent") + mock_agent_with_tool_calls_sync.run( + messages=[ChatMessage.from_user("What's the weather in Berlin?")], break_point=agent_breakpoints + ) diff --git a/test/components/agents/test_agent_breakpoints_utils.py b/test/components/agents/test_agent_breakpoints_utils.py new file mode 100644 index 0000000000..3edd23100f --- /dev/null +++ b/test/components/agents/test_agent_breakpoints_utils.py @@ -0,0 +1,110 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from haystack.components.agents import Agent +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.dataclasses import ChatMessage, ToolCall +from haystack.tools import Tool +from test.components.agents.test_agent import ( + MockChatGeneratorWithoutRunAsync, + MockChatGeneratorWithRunAsync, + weather_function, +) + + +# Common fixtures +@pytest.fixture +def weather_tool(): + return Tool( + name="weather_tool", + description="Provides weather information for a given location.", + parameters={"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}, + function=weather_function, + ) + + +@pytest.fixture +def debug_path(tmp_path): + return str(tmp_path / "debug_snapshots") + + +@pytest.fixture +def agent_sync(weather_tool): + generator = MockChatGeneratorWithoutRunAsync() + mock_run = MagicMock() + mock_run.return_value = { + "replies": [ + ChatMessage.from_assistant( + "I'll help you check the weather.", + tool_calls=[{"tool_name": "weather_tool", "tool_args": {"location": "Berlin"}}], + ) + ] + } + + def mock_run_with_tools(messages, tools=None, **kwargs): + return mock_run.return_value + + generator.run = mock_run_with_tools + + return Agent( + chat_generator=generator, + tools=[weather_tool], + system_prompt="You are a helpful assistant that can use tools to help users.", + ) + + +@pytest.fixture +def mock_agent_with_tool_calls_sync(monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = OpenAIChatGenerator() + mock_messages = [ + ChatMessage.from_assistant("First response"), + ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]), + ] + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1) + agent.warm_up() + agent.chat_generator.run = MagicMock(return_value={"replies": mock_messages}) + return agent + + +@pytest.fixture +def agent_async(weather_tool): + generator = MockChatGeneratorWithRunAsync() + mock_run_async = AsyncMock() + mock_run_async.return_value = { + "replies": [ + ChatMessage.from_assistant( + "I'll help you check the weather.", + tool_calls=[{"tool_name": "weather_tool", "tool_args": {"location": "Berlin"}}], + ) + ] + } + + async def mock_run_async_with_tools(messages, tools=None, **kwargs): + return mock_run_async.return_value + + generator.run_async = mock_run_async_with_tools + return Agent( + chat_generator=generator, + tools=[weather_tool], + system_prompt="You are a helpful assistant that can use tools to help 
users.", + ) + + +@pytest.fixture +def mock_agent_with_tool_calls_async(monkeypatch, weather_tool): + monkeypatch.setenv("OPENAI_API_KEY", "fake-key") + generator = MockChatGeneratorWithRunAsync() + mock_messages = [ + ChatMessage.from_assistant("First response"), + ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]), + ] + agent = Agent(chat_generator=generator, tools=[weather_tool], max_agent_steps=1) + agent.warm_up() + agent.chat_generator.run_async = AsyncMock(return_value={"replies": mock_messages}) + return agent diff --git a/test/conftest.py b/test/conftest.py index f8c9ff6a1e..eeef71d923 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,7 @@ import pytest from haystack import component, tracing +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot from haystack.testing.test_utils import set_all_seeds from test.tracing.utils import SpyingTracer @@ -82,6 +83,36 @@ def spying_tracer() -> Generator[SpyingTracer, None, None]: tracing.disable_tracing() +def load_and_resume_pipeline_snapshot(pipeline, output_directory: Path, component_name: str, data: Dict = None) -> Dict: + """ + Utility function to load and resume pipeline snapshot from a breakpoint file. + + :param pipeline: The pipeline instance to resume + :param output_directory: Directory containing the breakpoint files + :param component_name: Component name to look for in breakpoint files + :param data: Data to pass to the pipeline run (defaults to empty dict) + + :returns: + Dict containing the pipeline run results + + :raises: + ValueError: If no breakpoint file is found for the given component + """ + data = data or {} + all_files = list(output_directory.glob("*")) + file_found = False + + for full_path in all_files: + f_name = Path(full_path).name + if str(f_name).startswith(component_name): + pipeline_snapshot = load_pipeline_snapshot(full_path) + return pipeline.run(data=data, pipeline_snapshot=pipeline_snapshot) + + if not file_found: + msg = f"No files found for {component_name} in {output_directory}." 
+ raise ValueError(msg) + + @pytest.fixture() def base64_image_string(): return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+ip1sAAAAASUVORK5CYII=" diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py new file mode 100644 index 0000000000..7e9a53f7ad --- /dev/null +++ b/test/core/pipeline/test_breakpoint.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest + +from haystack import component +from haystack.core.errors import BreakpointException +from haystack.core.pipeline import Pipeline +from haystack.core.pipeline.breakpoint import _transform_json_structure, load_pipeline_snapshot +from haystack.dataclasses.breakpoints import Breakpoint, PipelineSnapshot + + +def test_transform_json_structure_unwraps_sender_value(): + data = { + "key1": [{"sender": None, "value": "some value"}], + "key2": [{"sender": "comp1", "value": 42}], + "key3": "direct value", + } + + result = _transform_json_structure(data) + + assert result == {"key1": "some value", "key2": 42, "key3": "direct value"} + + +def test_transform_json_structure_handles_nested_structures(): + data = { + "key1": [{"sender": None, "value": "value1"}], + "key2": {"nested": [{"sender": "comp1", "value": "value2"}], "direct": "value3"}, + "key3": [[{"sender": None, "value": "value4"}], [{"sender": "comp2", "value": "value5"}]], + } + + result = _transform_json_structure(data) + + assert result == {"key1": "value1", "key2": {"nested": "value2", "direct": "value3"}, "key3": ["value4", "value5"]} + + +def test_load_pipeline_snapshot_loads_valid_snapshot(tmp_path): + pipeline_snapshot = { + "break_point": {"component_name": "comp1", "visit_count": 0}, + "pipeline_state": {"inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, "pipeline_outputs": {}}, + "original_input_data": {}, + "ordered_component_names": ["comp1", "comp2"], + "include_outputs_from": ["comp1", "comp2"], + } + pipeline_snapshot_file = tmp_path / "state.json" + with open(pipeline_snapshot_file, "w") as f: + json.dump(pipeline_snapshot, f) + + loaded_snapshot = load_pipeline_snapshot(pipeline_snapshot_file) + assert loaded_snapshot == PipelineSnapshot.from_dict(pipeline_snapshot) + + +def test_load_state_handles_invalid_state(tmp_path): + pipeline_snapshot = { + "break_point": {"component_name": "comp1", "visit_count": 0}, + "pipeline_state": {"inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, "pipeline_outputs": {}}, + "original_input_data": {}, + "include_outputs_from": ["comp1", "comp2"], + "ordered_component_names": ["comp1", "comp3"], # inconsistent with component_visits + } + + pipeline_snapshot_file = tmp_path / "invalid_pipeline_snapshot.json" + with open(pipeline_snapshot_file, "w") as f: + json.dump(pipeline_snapshot, f) + + with pytest.raises(ValueError, match="Invalid pipeline snapshot from"): + load_pipeline_snapshot(pipeline_snapshot_file) + + +def test_breakpoint_saves_intermediate_outputs(tmp_path): + @component + class SimpleComponent: + @component.output_types(result=str) + def run(self, input_value: str) -> dict[str, str]: + return {"result": f"processed_{input_value}"} + + pipeline = Pipeline() + comp1 = SimpleComponent() + comp2 = SimpleComponent() + pipeline.add_component("comp1", comp1) + pipeline.add_component("comp2", comp2) + pipeline.connect("comp1", "comp2") + + # breakpoint on comp2 + break_point = Breakpoint(component_name="comp2", visit_count=0, snapshot_file_path=str(tmp_path)) 
+ + try: + # run with include_outputs_from to capture intermediate outputs + pipeline.run(data={"comp1": {"input_value": "test"}}, include_outputs_from={"comp1"}, break_point=break_point) + except BreakpointException as e: + # breakpoint should be triggered + assert e.component == "comp2" + + # verify snapshot file contains the intermediate outputs + snapshot_files = list(tmp_path.glob("comp2_*.json")) + assert len(snapshot_files) == 1, f"Expected exactly one snapshot file, found {len(snapshot_files)}" + + snapshot_file = snapshot_files[0] + loaded_snapshot = load_pipeline_snapshot(snapshot_file) + + # verify the snapshot contains the intermediate outputs from comp1 + assert "comp1" in loaded_snapshot.pipeline_state.pipeline_outputs + assert loaded_snapshot.pipeline_state.pipeline_outputs["comp1"]["result"] == "processed_test" + + # verify the whole pipeline state contains the expected data + assert loaded_snapshot.pipeline_state.component_visits["comp1"] == 1 + assert loaded_snapshot.pipeline_state.component_visits["comp2"] == 0 + assert "comp1" in loaded_snapshot.include_outputs_from + assert loaded_snapshot.break_point.component_name == "comp2" + assert loaded_snapshot.break_point.visit_count == 0 diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py new file mode 100644 index 0000000000..af7cf70326 --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from haystack.components.builders.answer_builder import AnswerBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.joiners import AnswerJoiner +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_snapshot + + +class TestPipelineBreakpoints: + @pytest.fixture + def mock_openai_chat_generator(self): + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content="Natural Language Processing (NLP) is a field of AI focused on enabling " + "computers to understand, interpret, and generate human language." + ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(model_name): + generator = OpenAIChatGenerator(model=model_name, api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + if "gpt-4" in model_name: + content = ( + "Natural Language Processing (NLP) is a field of AI focused on enabling computers " + "to understand, interpret, and generate human language." + ) + else: + content = "NLP is a branch of AI that helps machines understand and process human language." 
+ + return { + "replies": [ChatMessage.from_assistant(content)], + "meta": { + "model": model_name, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + yield create_mock_generator + + @pytest.fixture + def answer_join_pipeline(self, mock_openai_chat_generator): + """ + Creates a pipeline with mocked OpenAI components. + """ + # Create the pipeline with mocked components + pipeline = Pipeline(connection_type_validation=False) + pipeline.add_component("gpt-4o", mock_openai_chat_generator("gpt-4o")) + pipeline.add_component("gpt-3", mock_openai_chat_generator("gpt-3.5-turbo")) + pipeline.add_component("answer_builder_a", AnswerBuilder()) + pipeline.add_component("answer_builder_b", AnswerBuilder()) + pipeline.add_component("answer_joiner", AnswerJoiner()) + pipeline.connect("gpt-4o.replies", "answer_builder_a") + pipeline.connect("gpt-3.replies", "answer_builder_b") + pipeline.connect("answer_builder_a.answers", "answer_joiner") + pipeline.connect("answer_builder_b.answers", "answer_joiner") + + return pipeline + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("output_files") + + BREAKPOINT_COMPONENTS = ["gpt-4o", "gpt-3", "answer_builder_a", "answer_builder_b", "answer_joiner"] + + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) + @pytest.mark.integration + def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_directory, component): + """ + Test that an answer joiner pipeline can be executed with breakpoints at each component. + """ + query = "What's Natural Language Processing?" + messages = [ + ChatMessage.from_system("You are a helpful, respectful and honest assistant. 
Be super concise."), + ChatMessage.from_user(query), + ] + data = { + "gpt-4o": {"messages": messages}, + "gpt-3": {"messages": messages}, + "answer_builder_a": {"query": query}, + "answer_builder_b": {"query": query}, + } + + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) + + try: + _ = answer_join_pipeline.run(data, break_point=break_point) + except BreakpointException: + pass + + result = load_and_resume_pipeline_snapshot( + pipeline=answer_join_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) + assert result["answer_joiner"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py new file mode 100644 index 0000000000..e5e2618891 --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from haystack.components.converters import OutputAdapter +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.joiners import BranchJoiner +from haystack.components.validators import JsonSchemaValidator +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_snapshot + + +class TestPipelineBreakpoints: + @pytest.fixture + def mock_openai_chat_generator(self): + """ + Creates a mock for the OpenAIChatGenerator. 
+ """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + # Create mock completion objects + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content='{"first_name": "Peter", "last_name": "Parker", "nationality": "American"}' + ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(model_name): + generator = OpenAIChatGenerator(model=model_name, api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + content = '{"first_name": "Peter", "last_name": "Parker", "nationality": "American"}' + + return { + "replies": [ChatMessage.from_assistant(content)], + "meta": { + "model": model_name, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + yield create_mock_generator + + @pytest.fixture + def branch_joiner_pipeline(self, mock_openai_chat_generator): + person_schema = { + "type": "object", + "properties": { + "first_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "last_name": {"type": "string", "pattern": "^[A-Z][a-z]+$"}, + "nationality": {"type": "string", "enum": ["Italian", "Portuguese", "American"]}, + }, + "required": ["first_name", "last_name", "nationality"], + } + + pipe = Pipeline() + pipe.add_component("joiner", BranchJoiner(List[ChatMessage])) + pipe.add_component("fc_llm", mock_openai_chat_generator("gpt-4o-mini")) + pipe.add_component("validator", JsonSchemaValidator(json_schema=person_schema)) + pipe.add_component("adapter", OutputAdapter("{{chat_message}}", List[ChatMessage], unsafe=True)) + + pipe.connect("adapter", "joiner") + pipe.connect("joiner", "fc_llm") + pipe.connect("fc_llm.replies", "validator.messages") + pipe.connect("validator.validation_error", "joiner") + + return pipe + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("output_files") + + BREAKPOINT_COMPONENTS = ["joiner", "fc_llm", "validator", "adapter"] + + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) + @pytest.mark.integration + def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output_directory, component): + data = { + "fc_llm": {"generation_kwargs": {"response_format": {"type": "json_object"}}}, + "adapter": {"chat_message": [ChatMessage.from_user("Create JSON from Peter Parker")]}, + } + + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) + + try: + _ = branch_joiner_pipeline.run(data, break_point=break_point) + except BreakpointException: + pass + + result = load_and_resume_pipeline_snapshot( + pipeline=branch_joiner_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) + assert result["validator"], "The result should be valid according to the schema." 
diff --git a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py new file mode 100644 index 0000000000..9a89abac3f --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.joiners import ListJoiner +from haystack.core.errors import BreakpointException +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_snapshot + + +class TestPipelineBreakpoints: + @pytest.fixture + def mock_openai_chat_generator(self): + """ + Creates a mock for the OpenAIChatGenerator. + """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + # Create mock completion objects + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content="Nuclear physics is the study of atomic nuclei, their constituents, " + "and their interactions." + ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(model_name): + generator = OpenAIChatGenerator(model=model_name, api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + # Check if this is a feedback request or a regular query + if any("feedback" in msg.text.lower() for msg in messages): + content = ( + "Score: 8/10. The answer is concise and accurate, providing a good overview " + "of nuclear physics." + ) + else: + content = ( + "Nuclear physics is the study of atomic nuclei, their constituents, and their interactions." + ) + + return { + "replies": [ChatMessage.from_assistant(content)], + "meta": { + "model": model_name, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + yield create_mock_generator + + @pytest.fixture + def list_joiner_pipeline(self, mock_openai_chat_generator): + user_message = [ChatMessage.from_user("Give a brief answer the following question: {{query}}")] + + feedback_prompt = """ + You are given a question and an answer. + Your task is to provide a score and a brief feedback on the answer. 
+ Question: {{query}} + Answer: {{response}} + """ + + feedback_message = [ChatMessage.from_system(feedback_prompt)] + + prompt_builder = ChatPromptBuilder(template=user_message) + feedback_prompt_builder = ChatPromptBuilder(template=feedback_message) + llm = mock_openai_chat_generator("gpt-4o-mini") + feedback_llm = mock_openai_chat_generator("gpt-4o-mini") + + pipe = Pipeline() + pipe.add_component("prompt_builder", prompt_builder) + pipe.add_component("llm", llm) + pipe.add_component("feedback_prompt_builder", feedback_prompt_builder) + pipe.add_component("feedback_llm", feedback_llm) + pipe.add_component("list_joiner", ListJoiner(List[ChatMessage])) + + pipe.connect("prompt_builder.prompt", "llm.messages") + pipe.connect("prompt_builder.prompt", "list_joiner") + pipe.connect("llm.replies", "list_joiner") + pipe.connect("llm.replies", "feedback_prompt_builder.response") + pipe.connect("feedback_prompt_builder.prompt", "feedback_llm.messages") + pipe.connect("feedback_llm.replies", "list_joiner") + + return pipe + + @pytest.fixture(scope="session") + def output_directory(self, tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("output_files") + + BREAKPOINT_COMPONENTS = ["prompt_builder", "llm", "feedback_prompt_builder", "feedback_llm", "list_joiner"] + + @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS) + @pytest.mark.integration + def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, component): + query = "What is nuclear physics?" + data = { + "prompt_builder": {"template_variables": {"query": query}}, + "feedback_prompt_builder": {"template_variables": {"query": query}}, + } + + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) + + try: + _ = list_joiner_pipeline.run(data, break_point=break_point) + except BreakpointException: + pass + + result = load_and_resume_pipeline_snapshot( + pipeline=list_joiner_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) + assert result["list_joiner"] diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py new file mode 100644 index 0000000000..f76dd6d48c --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from pathlib import Path +from typing import List, Optional +from unittest.mock import MagicMock, patch + +import pydantic +import pytest +from colorama import Fore +from pydantic import BaseModel, ValidationError + +from haystack import component +from haystack.components.builders import ChatPromptBuilder +from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.breakpoint import load_pipeline_snapshot +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses import ChatMessage +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.utils.auth import Secret + + +# Define the component input parameters +@component +class OutputValidator: + def __init__(self, pydantic_model: pydantic.BaseModel): + self.pydantic_model = pydantic_model + self.iteration_counter = 0 + + # Define the component output + @component.output_types(valid_replies=List[str], 
invalid_replies=Optional[List[str]], error_message=Optional[str]) + def run(self, replies: List[ChatMessage]): + self.iteration_counter += 1 + + ## Try to parse the LLM's reply ## + # If the LLM's reply is a valid object, return `"valid_replies"` + try: + output_dict = json.loads(replies[0].text) + self.pydantic_model.model_validate(output_dict) + print( + Fore.GREEN + f"OutputValidator at Iteration {self.iteration_counter}: " + f"Valid JSON from LLM - No need for looping: {replies[0]}" + ) + return {"valid_replies": replies} + + # If the LLM's reply is corrupted or not valid, return "invalid_replies" and the "error_message" for the LLM + # to try again + except (ValueError, ValidationError) as e: + print( + Fore.RED + + f"OutputValidator at Iteration {self.iteration_counter}: Invalid JSON from LLM - Let's try again.\n" + f"Output from LLM:\n {replies[0]} \n" + f"Error from OutputValidator: {e}" + ) + return {"invalid_replies": replies, "error_message": str(e)} + + +class City(BaseModel): + name: str + country: str + population: int + + +class CitiesData(BaseModel): + cities: List[City] + + +class TestPipelineBreakpointsLoops: + """ + This class contains tests for pipelines with validation loops and breakpoints. + """ + + @pytest.fixture + def mock_openai_chat_generator(self): + """ + Creates a mock for the OpenAIChatGenerator that returns valid JSON responses. + """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + # Create mock completion objects + mock_completion = MagicMock() + mock_completion.choices = [ + MagicMock( + finish_reason="stop", + index=0, + message=MagicMock( + content='{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}, ' + '{"name": "Paris", "country": "France", "population": 2161000}, ' + '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}' + ), + ) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + + # Create a mock for the OpenAIChatGenerator + @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"}) + def create_mock_generator(): + generator = OpenAIChatGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY")) + + # Mock the run method + def mock_run(messages, streaming_callback=None, generation_kwargs=None, tools=None, tools_strict=None): + # Check if this is a retry attempt + if any("You already created the following output" in msg.text for msg in messages): + # Return a valid JSON response for retry attempts + return { + "replies": [ + ChatMessage.from_assistant( + '{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}, ' + '{"name": "Paris", "country": "France", "population": 2161000}, ' + '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}' + ) + ], + "meta": { + "model": "gpt-4", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + else: + # Return a valid JSON response for initial attempts + return { + "replies": [ + ChatMessage.from_assistant( + '{"cities": [{"name": "Berlin", "country": "Germany", "population": 3850809}, ' + '{"name": "Paris", "country": "France", "population": 2161000}, ' + '{"name": "Lisbon", "country": "Portugal", "population": 504718}]}' + ) + ], + "meta": { + "model": "gpt-4", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + }, + } + + # Replace the run method with our mock + generator.run = mock_run + + return generator + + 
+            yield create_mock_generator
+
+    @pytest.fixture
+    def validation_loop_pipeline(self, mock_openai_chat_generator):
+        """Create a pipeline with validation loops for testing."""
+        prompt_template = [
+            ChatMessage.from_user(
+                """
+                Create a JSON object from the information present in this passage: {{passage}}.
+                Only use information that is present in the passage. Follow this JSON schema, but only return the
+                actual instances without any additional schema definition:
+                {{schema}}
+                Make sure your response is a dict and not a list.
+                {% if invalid_replies and error_message %}
+                You already created the following output in a previous attempt: {{invalid_replies}}
+                However, this doesn't comply with the format requirements from above and triggered this
+                Python exception: {{error_message}}
+                Correct the output and try again. Just return the corrected output without any extra explanations.
+                {% endif %}
+                """
+            )
+        ]
+
+        pipeline = Pipeline(max_runs_per_component=5)
+        pipeline.add_component(instance=ChatPromptBuilder(template=prompt_template), name="prompt_builder")
+        pipeline.add_component(instance=mock_openai_chat_generator(), name="llm")
+        pipeline.add_component(instance=OutputValidator(pydantic_model=CitiesData), name="output_validator")
+
+        # Connect components
+        pipeline.connect("prompt_builder.prompt", "llm.messages")
+        pipeline.connect("llm.replies", "output_validator")
+        pipeline.connect("output_validator.invalid_replies", "prompt_builder.invalid_replies")
+        pipeline.connect("output_validator.error_message", "prompt_builder.error_message")
+
+        return pipeline
+
+    @pytest.fixture(scope="session")
+    def output_directory(self, tmp_path_factory):
+        return tmp_path_factory.mktemp("output_files")
+
+    @pytest.fixture
+    def test_data(self):
+        json_schema = {
+            "cities": [
+                {"name": "Berlin", "country": "Germany", "population": 3850809},
+                {"name": "Paris", "country": "France", "population": 2161000},
+                {"name": "Lisbon", "country": "Portugal", "population": 504718},
+            ]
+        }
+
+        passage = (
+            "Berlin is the capital of Germany. It has a population of 3,850,809. Paris, France's capital, has "
+            "2.161 million residents. Lisbon is the capital and the largest city of Portugal with the "
+            "population of 504,718."
+        )
+
+        return {"schema": json_schema, "passage": passage}
+
+    BREAKPOINT_COMPONENTS = ["prompt_builder", "llm", "output_validator"]
+
+    @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS)
+    @pytest.mark.integration
+    def test_pipeline_breakpoints_validation_loop(
+        self, validation_loop_pipeline, output_directory, test_data, component
+    ):
+        """
+        Test that a pipeline with validation loops can be executed with breakpoints at each component.
+ """ + data = {"prompt_builder": {"passage": test_data["passage"], "schema": test_data["schema"]}} + + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) + + try: + _ = validation_loop_pipeline.run(data, break_point=break_point) + except BreakpointException: + pass + + all_files = list(output_directory.glob("*")) + file_found = False + for full_path in all_files: + f_name = Path(full_path).name + if str(f_name).startswith(break_point.component_name): + file_found = True + result = validation_loop_pipeline.run(data={}, pipeline_snapshot=load_pipeline_snapshot(full_path)) + # Verify the result contains valid output + if "output_validator" in result and "valid_replies" in result["output_validator"]: + valid_reply = result["output_validator"]["valid_replies"][0].text + valid_json = json.loads(valid_reply) + assert isinstance(valid_json, dict) + assert "cities" in valid_json + cities_data = CitiesData.model_validate(valid_json) + assert len(cities_data.cities) == 3 + if not file_found: + msg = f"No files found for {component} in {output_directory}." + raise ValueError(msg) diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py new file mode 100644 index 0000000000..820942c247 --- /dev/null +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from haystack import Document +from haystack.components.builders.answer_builder import AnswerBuilder +from haystack.components.builders.prompt_builder import PromptBuilder +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.generators import OpenAIGenerator +from haystack.components.joiners import DocumentJoiner +from haystack.components.rankers import TransformersSimilarityRanker +from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever +from haystack.components.writers import DocumentWriter +from haystack.core.errors import BreakpointException +from haystack.core.pipeline.pipeline import Pipeline +from haystack.dataclasses.breakpoints import Breakpoint +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.auth import Secret +from test.conftest import load_and_resume_pipeline_snapshot + + +class TestPipelineBreakpoints: + """ + This class contains tests for pipelines with breakpoints. 
+ """ + + @pytest.fixture + def mock_sentence_transformers_doc_embedder(self): + with patch( + "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory" # noqa: E501 + ) as mock_doc_embedder: + mock_model = MagicMock() + mock_doc_embedder.return_value = mock_model + + # the mock returns a fixed embedding + def mock_encode( + documents, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs + ): + import numpy as np + + return [np.ones(384).tolist() for _ in documents] + + mock_model.encode = mock_encode + embedder = SentenceTransformersDocumentEmbedder(model="mock-model", progress_bar=False) + + # mocked run method to return a fixed embedding + def mock_run(documents: list[Document]): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + raise TypeError( + "SentenceTransformersDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder." + ) + + import numpy as np + + embedding = np.ones(384).tolist() + + # Add the embedding to each document + for doc in documents: + doc.embedding = embedding + + # Return the documents with embeddings, matching the actual implementation + return {"documents": documents} + + # mocked run + embedder.run = mock_run + + # initialize the component + embedder.warm_up() + + return embedder + + @pytest.fixture + def document_store(self, mock_sentence_transformers_doc_embedder): + """Create and populate a document store for testing.""" + documents = [ + Document(content="My name is Jean and I live in Paris."), + Document(content="My name is Mark and I live in Berlin."), + Document(content="My name is Giorgio and I live in Rome."), + ] + + document_store = InMemoryDocumentStore() + doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP) + ingestion_pipe = Pipeline() + ingestion_pipe.add_component(instance=mock_sentence_transformers_doc_embedder, name="doc_embedder") + ingestion_pipe.add_component(instance=doc_writer, name="doc_writer") + ingestion_pipe.connect("doc_embedder.documents", "doc_writer.documents") + ingestion_pipe.run({"doc_embedder": {"documents": documents}}) + + return document_store + + @pytest.fixture + def mock_openai_completion(self): + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + mock_completion = MagicMock() + mock_completion.model = "gpt-4o-mini" + mock_completion.choices = [ + MagicMock(finish_reason="stop", index=0, message=MagicMock(content="Mark lives in Berlin.")) + ] + mock_completion.usage = {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97} + + mock_chat_completion_create.return_value = mock_completion + yield mock_chat_completion_create + + @pytest.fixture + def mock_transformers_similarity_ranker(self): + """ + This mock simulates the behavior of the ranker without loading the actual model. 
+ """ + with ( + patch( + "haystack.components.rankers.transformers_similarity.AutoModelForSequenceClassification" + ) as mock_model_class, + patch("haystack.components.rankers.transformers_similarity.AutoTokenizer") as mock_tokenizer_class, + ): + mock_model = MagicMock() + mock_tokenizer = MagicMock() + + mock_model_class.from_pretrained.return_value = mock_model + mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer + + ranker = TransformersSimilarityRanker(model="mock-model", top_k=5, scale_score=True, calibration_factor=1.0) + + def mock_run(query, documents, top_k=None, scale_score=None, calibration_factor=None, score_threshold=None): + # assign random scores + import random + + ranked_docs = documents.copy() + for doc in ranked_docs: + doc.score = random.random() # random score between 0 and 1 + + # sort reverse order and select top_k if provided + ranked_docs.sort(key=lambda x: x.score, reverse=True) + if top_k is not None: + ranked_docs = ranked_docs[:top_k] + else: + ranked_docs = ranked_docs[: ranker.top_k] + + # apply score threshold if provided + if score_threshold is not None: + ranked_docs = [doc for doc in ranked_docs if doc.score >= score_threshold] + + return {"documents": ranked_docs} + + # replace the run method with our mock + ranker.run = mock_run + + # warm_up to initialize the component + ranker.warm_up() + + return ranker + + @pytest.fixture + def mock_sentence_transformers_text_embedder(self): + """ + Simulates the behavior of the embedder without loading the actual model + """ + with patch( + "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory" + ) as mock_text_embedder: # noqa: E501 + mock_model = MagicMock() + mock_text_embedder.return_value = mock_model + + # the mock returns a fixed embedding + def mock_encode( + texts, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs + ): + import numpy as np + + return [np.ones(384).tolist() for _ in texts] + + mock_model.encode = mock_encode + embedder = SentenceTransformersTextEmbedder(model="mock-model", progress_bar=False) + + # mocked run method to return a fixed embedding + def mock_run(text): + if not isinstance(text, str): + raise TypeError( + "SentenceTransformersTextEmbedder expects a string as input." + "In case you want to embed a list of Documents, please use the " + "SentenceTransformersDocumentEmbedder." 
+                    )
+
+                import numpy as np
+
+                embedding = np.ones(384).tolist()
+                return {"embedding": embedding}
+
+            # mocked run
+            embedder.run = mock_run
+
+            # initialize the component
+            embedder.warm_up()
+
+            return embedder
+
+    @pytest.fixture
+    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-api-key"})
+    def hybrid_rag_pipeline(
+        self, document_store, mock_transformers_similarity_ranker, mock_sentence_transformers_text_embedder
+    ):
+        """Create a hybrid RAG pipeline for testing."""
+
+        prompt_template = """
+        Given these documents, answer the question based on the document content only.\nDocuments:
+        {% for doc in documents %}
+            {{ doc.content }}
+        {% endfor %}
+
+        \nQuestion: {{question}}
+        \nAnswer:
+        """
+        pipeline = Pipeline(connection_type_validation=False)
+        pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever")
+
+        # Use the mocked embedder instead of creating a new one
+        pipeline.add_component(instance=mock_sentence_transformers_text_embedder, name="query_embedder")
+
+        pipeline.add_component(
+            instance=InMemoryEmbeddingRetriever(document_store=document_store), name="embedding_retriever"
+        )
+        pipeline.add_component(instance=DocumentJoiner(sort_by_score=False), name="doc_joiner")
+
+        # Use the mocked ranker instead of the real one
+        pipeline.add_component(instance=mock_transformers_similarity_ranker, name="ranker")
+
+        pipeline.add_component(
+            instance=PromptBuilder(template=prompt_template, required_variables=["documents", "question"]),
+            name="prompt_builder",
+        )
+
+        # Use a mocked API key for the OpenAIGenerator
+        pipeline.add_component(instance=OpenAIGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY")), name="llm")
+        pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
+
+        pipeline.connect("query_embedder", "embedding_retriever.query_embedding")
+        pipeline.connect("embedding_retriever", "doc_joiner.documents")
+        pipeline.connect("bm25_retriever", "doc_joiner.documents")
+        pipeline.connect("doc_joiner", "ranker.documents")
+        pipeline.connect("ranker", "prompt_builder.documents")
+        pipeline.connect("prompt_builder", "llm")
+        pipeline.connect("llm.replies", "answer_builder.replies")
+        pipeline.connect("llm.meta", "answer_builder.meta")
+        pipeline.connect("doc_joiner", "answer_builder.documents")
+
+        return pipeline
+
+    @pytest.fixture(scope="session")
+    def output_directory(self, tmp_path_factory) -> Path:
+        return tmp_path_factory.mktemp("output_files")
+
+    BREAKPOINT_COMPONENTS = [
+        "bm25_retriever",
+        "query_embedder",
+        "embedding_retriever",
+        "doc_joiner",
+        "ranker",
+        "prompt_builder",
+        "llm",
+        "answer_builder",
+    ]
+
+    @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS)
+    @pytest.mark.integration
+    def test_pipeline_breakpoints_hybrid_rag(
+        self, hybrid_rag_pipeline, document_store, output_directory, component, mock_openai_completion
+    ):
+        """
+        Test that a hybrid RAG pipeline can be executed with breakpoints at each component.
+        """
+        # Test data
+        question = "Where does Mark live?"
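+        # The same `data` mapping is passed again when resuming from the snapshot below,
+        # so every component that consumes the question receives it on both runs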
+        data = {
+            "query_embedder": {"text": question},
+            "bm25_retriever": {"query": question},
+            "ranker": {"query": question, "top_k": 5},
+            "prompt_builder": {"question": question},
+            "answer_builder": {"query": question},
+        }
+
+        # Create a Breakpoint on-the-fly using the shared output directory
+        break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory))
+
+        try:
+            _ = hybrid_rag_pipeline.run(data, break_point=break_point)
+        except BreakpointException:
+            pass
+
+        result = load_and_resume_pipeline_snapshot(
+            pipeline=hybrid_rag_pipeline,
+            output_directory=output_directory,
+            component_name=break_point.component_name,
+            data=data,
+        )
+        assert result["answer_builder"]
diff --git a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py
new file mode 100644
index 0000000000..1ab96c34f6
--- /dev/null
+++ b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from haystack.components.builders.chat_prompt_builder import ChatPromptBuilder
+from haystack.components.converters import OutputAdapter
+from haystack.components.joiners import StringJoiner
+from haystack.core.errors import BreakpointException
+from haystack.core.pipeline.pipeline import Pipeline
+from haystack.dataclasses import ChatMessage
+from haystack.dataclasses.breakpoints import Breakpoint
+from test.conftest import load_and_resume_pipeline_snapshot
+
+
+class TestPipelineBreakpoints:
+    @pytest.fixture
+    def string_joiner_pipeline(self):
+        pipeline = Pipeline()
+        pipeline.add_component(
+            "prompt_builder_1", ChatPromptBuilder(template=[ChatMessage.from_user("Builder 1: {{query}}")])
+        )
+        pipeline.add_component(
+            "prompt_builder_2", ChatPromptBuilder(template=[ChatMessage.from_user("Builder 2: {{query}}")])
+        )
+        pipeline.add_component("adapter_1", OutputAdapter("{{messages[0].text}}", output_type=str))
+        pipeline.add_component("adapter_2", OutputAdapter("{{messages[0].text}}", output_type=str))
+        pipeline.add_component("string_joiner", StringJoiner())
+
+        pipeline.connect("prompt_builder_1.prompt", "adapter_1.messages")
+        pipeline.connect("prompt_builder_2.prompt", "adapter_2.messages")
+        pipeline.connect("adapter_1", "string_joiner.strings")
+        pipeline.connect("adapter_2", "string_joiner.strings")
+
+        return pipeline
+
+    @pytest.fixture(scope="session")
+    def output_directory(self, tmp_path_factory):
+        return tmp_path_factory.mktemp("output_files")
+
+    BREAKPOINT_COMPONENTS = ["prompt_builder_1", "prompt_builder_2", "adapter_1", "adapter_2", "string_joiner"]
+
+    @pytest.mark.parametrize("component", BREAKPOINT_COMPONENTS, ids=BREAKPOINT_COMPONENTS)
+    @pytest.mark.integration
+    def test_string_joiner_pipeline(self, string_joiner_pipeline, output_directory, component):
+        string_1 = "What's Natural Language Processing?"
+        string_2 = "What is life?"
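+        # The two builder/adapter branches fan into a single StringJoiner, so this test
+        # also exercises resuming a pipeline whose snapshot spans multiple upstream branches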
+ data = {"prompt_builder_1": {"query": string_1}, "prompt_builder_2": {"query": string_2}} + + # Create a Breakpoint on-the-fly using the shared output directory + break_point = Breakpoint(component_name=component, visit_count=0, snapshot_file_path=str(output_directory)) + + try: + _ = string_joiner_pipeline.run(data, break_point=break_point) + except BreakpointException: + pass + + result = load_and_resume_pipeline_snapshot( + pipeline=string_joiner_pipeline, + output_directory=output_directory, + component_name=break_point.component_name, + data=data, + ) + assert result["string_joiner"]