diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py
index 5df6549159..8cea8cc16a 100644
--- a/haystack/components/agents/agent.py
+++ b/haystack/components/agents/agent.py
@@ -210,9 +210,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent":
         return default_from_dict(cls, data)
 
-    def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]:
+    def _prepare_generator_inputs(
+        self,
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
         """Prepare inputs for the chat generator."""
         generator_inputs: Dict[str, Any] = {"tools": self.tools}
+        if generation_kwargs is not None:
+            generator_inputs["generation_kwargs"] = generation_kwargs
         if streaming_callback is not None:
             generator_inputs["streaming_callback"] = streaming_callback
         return generator_inputs
 
@@ -230,7 +236,11 @@ def _create_agent_span(self) -> Any:
         )
 
     def run(
-        self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any
     ) -> Dict[str, Any]:
         """
         Process messages and execute tools until an exit condition is met.
@@ -239,6 +249,8 @@ def run(
             If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object.
         :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The
             same callback can be configured to emit tool results when a tool is called.
+        :param generation_kwargs: Additional keyword arguments for the LLM. These parameters will
+            override the parameters passed during component initialization.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :returns:
@@ -267,7 +279,7 @@ def run(
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
         )
-        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
+        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs)
         with self._create_agent_span() as span:
             span.set_content_tag(
                 "haystack.agent.input",
@@ -328,7 +340,11 @@ def run(
         return result
 
     async def run_async(
-        self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[StreamingCallbackT] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any
     ) -> Dict[str, Any]:
         """
         Asynchronously process messages and execute tools until the exit condition is met.
@@ -341,6 +357,8 @@ async def run_async(
 
         :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed
             from the LLM. The same callback can be configured to emit tool results when a tool is called.
+        :param generation_kwargs: Additional keyword arguments for the LLM. These parameters will
+            override the parameters passed during component initialization.
         :param kwargs: Additional data to pass to the State schema used by the Agent.
             The keys must match the schema defined in the Agent's `state_schema`.
         :returns:
@@ -369,7 +387,7 @@ async def run_async(
         streaming_callback = select_streaming_callback(
             init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
         )
-        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
+        generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs)
         with self._create_agent_span() as span:
             span.set_content_tag(
                 "haystack.agent.input",
diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py
index 61e45e9610..221f4c2f08 100644
--- a/test/components/agents/test_agent.py
+++ b/test/components/agents/test_agent.py
@@ -746,7 +746,7 @@ def test_run(self, weather_tool):
         chat_generator = OpenAIChatGenerator(model="gpt-4o-mini")
         agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
         agent.warm_up()
-        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")])
+        response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")], generation_kwargs={"temperature": 0.0})
 
         assert isinstance(response, dict)
         assert "messages" in response
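
End-to-end usage sketch of the change above (illustration only, not part of the patch): init-time generation_kwargs act as per-generator defaults, and the new run-time argument overrides them for a single call. The weather tool is a hypothetical stand-in modeled on the test fixture, and an OpenAI API key is assumed to be set in the environment.

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool


def get_weather(city: str) -> dict:
    # Hypothetical stand-in for a real weather lookup.
    return {"weather": "sunny", "temperature": 21, "unit": "celsius"}


weather_tool = Tool(
    name="weather_tool",
    description="Provides weather information for a given city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

# temperature=0.7 is the init-time default for every call made by this generator.
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini", generation_kwargs={"temperature": 0.7})
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3)
agent.warm_up()

# The runtime value overrides the init-time default for this call only.
result = agent.run(
    [ChatMessage.from_user("What is the weather in Berlin?")],
    generation_kwargs={"temperature": 0.0},
)
print(result["messages"][-1].text)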