-
Notifications
You must be signed in to change notification settings - Fork 2.3k
feat: Add generation_kwargs to run methods of Agent #9616
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -210,9 +210,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "Agent": | |||||||||||
|
||||||||||||
return default_from_dict(cls, data) | ||||||||||||
|
||||||||||||
def _prepare_generator_inputs(self, streaming_callback: Optional[StreamingCallbackT] = None) -> Dict[str, Any]: | ||||||||||||
def _prepare_generator_inputs( | ||||||||||||
self, | ||||||||||||
streaming_callback: Optional[StreamingCallbackT] = None, | ||||||||||||
generation_kwargs: Optional[Dict[str, Any]] = None | ||||||||||||
) -> Dict[str, Any]: | ||||||||||||
"""Prepare inputs for the chat generator.""" | ||||||||||||
generator_inputs: Dict[str, Any] = {"tools": self.tools} | ||||||||||||
if generation_kwargs is not None: | ||||||||||||
generator_inputs["generation_kwargs"] = generation_kwargs | ||||||||||||
if streaming_callback is not None: | ||||||||||||
generator_inputs["streaming_callback"] = streaming_callback | ||||||||||||
return generator_inputs | ||||||||||||
|
@@ -230,7 +236,11 @@ def _create_agent_span(self) -> Any: | |||||||||||
) | ||||||||||||
|
||||||||||||
def run( | ||||||||||||
self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any | ||||||||||||
self, | ||||||||||||
messages: List[ChatMessage], | ||||||||||||
streaming_callback: Optional[StreamingCallbackT] = None, | ||||||||||||
generation_kwargs: Optional[Dict[str, Any]] = None, | ||||||||||||
**kwargs: Any | ||||||||||||
) -> Dict[str, Any]: | ||||||||||||
""" | ||||||||||||
Process messages and execute tools until an exit condition is met. | ||||||||||||
|
@@ -239,6 +249,8 @@ def run( | |||||||||||
If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object. | ||||||||||||
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM. | ||||||||||||
The same callback can be configured to emit tool results when a tool is called. | ||||||||||||
:param generation_kwargs: Additional keyword arguments for the LLM. These parameters will | ||||||||||||
override the parameters passed during component initialization. | ||||||||||||
:param kwargs: Additional data to pass to the State schema used by the Agent. | ||||||||||||
The keys must match the schema defined in the Agent's `state_schema`. | ||||||||||||
:returns: | ||||||||||||
|
@@ -267,7 +279,7 @@ def run( | |||||||||||
streaming_callback = select_streaming_callback( | ||||||||||||
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False | ||||||||||||
) | ||||||||||||
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) | ||||||||||||
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) | ||||||||||||
jdb78 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
with self._create_agent_span() as span: | ||||||||||||
span.set_content_tag( | ||||||||||||
"haystack.agent.input", | ||||||||||||
|
@@ -328,7 +340,11 @@ def run( | |||||||||||
return result | ||||||||||||
|
||||||||||||
async def run_async( | ||||||||||||
self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any | ||||||||||||
self, | ||||||||||||
messages: List[ChatMessage], | ||||||||||||
streaming_callback: Optional[StreamingCallbackT] = None, | ||||||||||||
generation_kwargs: Optional[Dict[str, Any]] = None, | ||||||||||||
Comment on lines
+345
to
+346
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here
Suggested change
|
||||||||||||
**kwargs: Any | ||||||||||||
) -> Dict[str, Any]: | ||||||||||||
""" | ||||||||||||
Asynchronously process messages and execute tools until the exit condition is met. | ||||||||||||
|
@@ -341,6 +357,8 @@ async def run_async( | |||||||||||
:param streaming_callback: An asynchronous callback that will be invoked when a response | ||||||||||||
is streamed from the LLM. The same callback can be configured to emit tool results | ||||||||||||
when a tool is called. | ||||||||||||
:param generation_kwargs: Additional keyword arguments for the LLM. These parameters will | ||||||||||||
override the parameters passed during component initialization. | ||||||||||||
:param kwargs: Additional data to pass to the State schema used by the Agent. | ||||||||||||
The keys must match the schema defined in the Agent's `state_schema`. | ||||||||||||
:returns: | ||||||||||||
|
@@ -369,7 +387,7 @@ async def run_async( | |||||||||||
streaming_callback = select_streaming_callback( | ||||||||||||
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True | ||||||||||||
) | ||||||||||||
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback) | ||||||||||||
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback, generation_kwargs=generation_kwargs) | ||||||||||||
with self._create_agent_span() as span: | ||||||||||||
span.set_content_tag( | ||||||||||||
"haystack.agent.input", | ||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -746,7 +746,7 @@ def test_run(self, weather_tool): | |
chat_generator = OpenAIChatGenerator(model="gpt-4o-mini") | ||
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], max_agent_steps=3) | ||
agent.warm_up() | ||
response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")]) | ||
response = agent.run([ChatMessage.from_user("What is the weather in Berlin?")], generation_kwargs={"temperature": 0.0}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we add a new unit test instead that reuses |
||
|
||
assert isinstance(response, dict) | ||
assert "messages" in response | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since I know other work is being done on this run method in a separate feature branch let's go ahead and make
generation_kwargs
a keyword only arg.