From 086e302e5b91672984c59d42741533fb5cac8f2f Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 16 Jul 2025 21:16:25 +0100 Subject: [PATCH 1/3] add generation_kwargs to run methods of Agent --- haystack/components/agents/agent.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 5df6549159..51d8008f25 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -210,9 +210,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent": return default_from_dict(cls, data) - def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]: + def _prepare_generator_inputs( + self, + streaming_callback: Optional[StreamingCallbackT] = None, + generation_kwargs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Prepare inputs for the chat generator.""" generator_inputs: Dict[str, Any] = {"tools": self.tools} + if generation_kwargs is not None: + generator_inputs["generation_kwargs"] = generation_kwargs if streaming_callback is not None: generator_inputs["streaming_callback"] = streaming_callback return generator_inputs @@ -230,7 +236,11 @@ def _create_agent_span(self) -> Any: ) def run( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + self, + messages: List[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any ) -> Dict[str, Any]: """ Process messages and execute tools until an exit condition is met. @@ -239,6 +249,8 @@ def run( If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object. :param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -267,7 +279,7 @@ def run( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generator_inputs=generation_kwargs) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", @@ -328,7 +340,11 @@ def run( return result async def run_async( - self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any + self, + messages: List[ChatMessage], + streaming_callback: Optional[StreamingCallbackT] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any ) -> Dict[str, Any]: """ Asynchronously process messages and execute tools until the exit condition is met. @@ -341,6 +357,8 @@ async def run_async( :param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the LLM. The same callback can be configured to emit tool results when a tool is called. + :param generation_kwargs: Additional keyword arguments for LLM. These parameters will + override the parameters passed during component initialization. :param kwargs: Additional data to pass to the State schema used by the Agent. The keys must match the schema defined in the Agent's `state_schema`. :returns: @@ -369,7 +387,7 @@ async def run_async( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input", From 182f2de3527d4db5529af1e2b055a1e972f2911c Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 16 Jul 2025 21:23:27 +0100 Subject: [PATCH 2/3] include test generation_kwargs in agent unit tests --- test/components/agents/test_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/agents/test_agent.py b/test/components/agents/test_agent.py index 61e45e9610..221f4c2f08 100644 --- a/test/components/agents/test_agent.py +++ b/test/components/agents/test_agent.py @@ -746,7 +746,7 @@ def test_run(self, weather_tool): chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3) agent.warm_up() - response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) + response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")], generation_kwargs={"temperature": 0.0}) assert isinstance(response, dict) assert "messages" in response From 20995d7979553ac359d5ed45c3e26ac81e3d18c5 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Thu, 17 Jul 2025 08:57:31 +0100 Subject: [PATCH 3/3] fix typo Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack/components/agents/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/agents/agent.py b/haystack/components/agents/agent.py index 51d8008f25..8cea8cc16a 100644 --- a/haystack/components/agents/agent.py +++ b/haystack/components/agents/agent.py @@ -279,7 +279,7 @@ def run( streaming_callback = select_streaming_callback( init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False ) - generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generator_inputs=generation_kwargs) + generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) with self._create_agent_span() as span: span.set_content_tag( "haystack.agent.input",