Commit f938a1b

feat: support basic function call for gemini (google-generativeai) (#17696)
1 parent: ec4dc2c

File tree: 5 files changed (+215 / -44 lines)

llama-index-core/llama_index/core/agent/workflow/multi_agent_workflow.py

Lines changed: 27 additions & 1 deletion
@@ -51,6 +51,8 @@
 {msg}
 """
 
+DEFAULT_HANDOFF_OUTPUT_PROMPT = "Agent {to_agent} is now handling the request due to the following reason: {reason}.\nPlease continue with the current request."
+
 
 async def handoff(ctx: Context, to_agent: str, reason: str) -> str:
     """Handoff control of that chat to the given agent."""
@@ -61,7 +63,11 @@ async def handoff(ctx: Context, to_agent: str, reason: str) -> str:
         return f"Agent {to_agent} not found. Please select a valid agent to hand off to. Valid agents: {valid_agents}"
 
     await ctx.set("next_agent", to_agent)
-    return f"Agent {to_agent} is now handling the request due to the following reason: {reason}.\nPlease continue."
+    handoff_output_prompt = await ctx.get(
+        "handoff_output_prompt", default=DEFAULT_HANDOFF_OUTPUT_PROMPT
+    )
+
+    return handoff_output_prompt.format(to_agent=to_agent, reason=reason)
 
 
 class AgentWorkflowMeta(WorkflowMeta, ABCMeta):
@@ -77,6 +83,7 @@ def __init__(
         initial_state: Optional[Dict] = None,
         root_agent: Optional[str] = None,
         handoff_prompt: Optional[Union[str, BasePromptTemplate]] = None,
+        handoff_output_prompt: Optional[Union[str, BasePromptTemplate]] = None,
         state_prompt: Optional[Union[str, BasePromptTemplate]] = None,
         timeout: Optional[float] = None,
         **workflow_kwargs: Any,
@@ -106,6 +113,18 @@ def __init__(
            raise ValueError("Handoff prompt must contain {agent_info}")
        self.handoff_prompt = handoff_prompt
 
+        handoff_output_prompt = handoff_output_prompt or DEFAULT_HANDOFF_OUTPUT_PROMPT
+        if isinstance(handoff_output_prompt, str):
+            handoff_output_prompt = PromptTemplate(handoff_output_prompt)
+        if (
+            "{to_agent}" not in handoff_output_prompt.get_template()
+            or "{reason}" not in handoff_output_prompt.get_template()
+        ):
+            raise ValueError(
+                "Handoff output prompt must contain {to_agent} and {reason}"
+            )
+        self.handoff_output_prompt = handoff_output_prompt
+
        state_prompt = state_prompt or DEFAULT_STATE_PROMPT
        if isinstance(state_prompt, str):
            state_prompt = PromptTemplate(state_prompt)
@@ -120,6 +139,7 @@ def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "handoff_prompt": self.handoff_prompt,
+            "handoff_output_prompt": self.handoff_output_prompt,
            "state_prompt": self.state_prompt,
        }
 
@@ -131,6 +151,8 @@ def _update_prompts(self, prompts_dict: PromptDictType) -> None:
        """Update prompts."""
        if "handoff_prompt" in prompts_dict:
            self.handoff_prompt = prompts_dict["handoff_prompt"]
+        if "handoff_output_prompt" in prompts_dict:
+            self.handoff_output_prompt = prompts_dict["handoff_output_prompt"]
        if "state_prompt" in prompts_dict:
            self.state_prompt = prompts_dict["state_prompt"]
 
@@ -203,6 +225,10 @@ async def _init_context(self, ctx: Context, ev: StartEvent) -> None:
            await ctx.set("state", self.initial_state)
        if not await ctx.get("current_agent_name", default=None):
            await ctx.set("current_agent_name", self.root_agent)
+        if not await ctx.get("handoff_output_prompt", default=None):
+            await ctx.set(
+                "handoff_output_prompt", self.handoff_output_prompt.get_template()
+            )
 
    async def _call_tool(
        self,
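For reference, a minimal sketch of how the new handoff_output_prompt argument can be passed to AgentWorkflow. Only agents, root_agent, and handoff_output_prompt (with its required {to_agent} and {reason} placeholders) come from this diff; the two FunctionAgent definitions and the model name are illustrative assumptions, not part of the commit.

from llama_index.core.agent.workflow import AgentWorkflow, FunctionAgent
from llama_index.llms.gemini import Gemini

# Hypothetical agents for illustration; any workflow agents would do.
llm = Gemini(model="models/gemini-1.5-flash")  # model name is an assumption

researcher = FunctionAgent(
    name="researcher",
    description="Gathers background information, then hands off to the writer.",
    system_prompt="Research the topic, then hand off to the writer agent.",
    llm=llm,
)
writer = FunctionAgent(
    name="writer",
    description="Writes the final answer for the user.",
    system_prompt="Write the final response.",
    llm=llm,
)

workflow = AgentWorkflow(
    agents=[researcher, writer],
    root_agent="researcher",
    # The template must contain both {to_agent} and {reason};
    # otherwise __init__ raises ValueError (see the validation above).
    handoff_output_prompt=(
        "Control was handed to {to_agent} because: {reason}.\n"
        "Continue working on the user's request."
    ),
)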

llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/base.py

Lines changed: 121 additions & 33 deletions
@@ -2,7 +2,8 @@
 
 import os
 import warnings
-from typing import Any, Dict, Optional, Sequence, cast
+import uuid
+from typing import TYPE_CHECKING, Union, List, Any, Dict, Optional, Sequence, cast
 
 import google.generativeai as genai
 from google.generativeai.types import generation_types
@@ -13,15 +14,17 @@
     ChatResponseGen,
     CompletionResponse,
     CompletionResponseGen,
+    CompletionResponseAsyncGen,
     LLMMetadata,
+    MessageRole,
 )
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
+from llama_index.core.llms.llm import ToolSelection
 from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
-from llama_index.core.llms.custom import CustomLLM
+from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.utilities.gemini_utils import (
-    ROLES_FROM_GEMINI,
     merge_neighboring_same_role_messages,
 )
 
@@ -49,8 +52,11 @@
     "gemini-1.0-pro",
 )
 
+if TYPE_CHECKING:
+    from llama_index.core.tools.types import BaseTool
 
-class Gemini(CustomLLM):
+
+class Gemini(FunctionCallingLLM):
     """
     Gemini LLM.
 
@@ -181,6 +187,8 @@ def metadata(self) -> LLMMetadata:
            num_output=self.max_tokens,
            model_name=self.model,
            is_chat_model=True,
+            # All gemini models support function calling
+            is_function_calling_model=True,
        )
 
    @llm_completion_callback()
@@ -208,10 +216,30 @@ def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        request_options = self._request_options or kwargs.pop("request_options", None)
-        it = self._model.generate_content(
-            prompt, stream=True, request_options=request_options, **kwargs
-        )
-        yield from map(completion_from_gemini_response, it)
+
+        def gen():
+            it = self._model.generate_content(
+                prompt, stream=True, request_options=request_options, **kwargs
+            )
+            for r in it:
+                yield completion_from_gemini_response(r)
+
+        return gen()
+
+    @llm_completion_callback()
+    def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        request_options = self._request_options or kwargs.pop("request_options", None)
+
+        async def gen():
+            it = await self._model.generate_content_async(
+                prompt, stream=True, request_options=request_options, **kwargs
+            )
+            async for r in it:
+                yield completion_from_gemini_response(r)
+
+        return gen()
 
    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
@@ -256,19 +284,11 @@ def gen() -> ChatResponseGen:
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
-                role = ROLES_FROM_GEMINI[top_candidate.content.role]
-                raw = {
-                    **(type(top_candidate).to_dict(top_candidate)),  # type: ignore
-                    **(
-                        type(response.prompt_feedback).to_dict(response.prompt_feedback)  # type: ignore
-                    ),
-                }
                content += content_delta
-                yield ChatResponse(
-                    message=ChatMessage(role=role, content=content),
-                    delta=content_delta,
-                    raw=raw,
-                )
+                llama_resp = chat_from_gemini_response(r)
+                llama_resp.delta = content_delta
+                llama_resp.message.content = content
+                yield llama_resp
 
        return gen()
 
@@ -278,7 +298,7 @@ async def astream_chat(
    ) -> ChatResponseAsyncGen:
        request_options = self._request_options or kwargs.pop("request_options", None)
        merged_messages = merge_neighboring_same_role_messages(messages)
-        *history, next_msg = map(chat_message_to_gemini, messages)
+        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = await chat.send_message_async(
            next_msg, stream=True, request_options=request_options, **kwargs
@@ -289,18 +309,86 @@ async def gen() -> ChatResponseAsyncGen:
            async for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
-                role = ROLES_FROM_GEMINI[top_candidate.content.role]
-                raw = {
-                    **(type(top_candidate).to_dict(top_candidate)),  # type: ignore
-                    **(
-                        type(response.prompt_feedback).to_dict(response.prompt_feedback)  # type: ignore
-                    ),
-                }
                content += content_delta
-                yield ChatResponse(
-                    message=ChatMessage(role=role, content=content),
-                    delta=content_delta,
-                    raw=raw,
-                )
+                llama_resp = chat_from_gemini_response(r)
+                llama_resp.delta = content_delta
+                llama_resp.message.content = content
+                yield llama_resp
 
        return gen()
+
+    def _prepare_chat_with_tools(
+        self,
+        tools: Sequence["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        tool_choice: Union[str, dict] = "auto",
+        strict: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Predict and call the tool."""
+        from google.generativeai.types import FunctionDeclaration, ToolDict
+
+        tool_declarations = []
+        for tool in tools:
+            descriptions = {}
+            for param_name, param_schema in tool.metadata.get_parameters_dict()[
+                "properties"
+            ].items():
+                param_description = param_schema.get("description", None)
+                if param_description:
+                    descriptions[param_name] = param_description
+
+            tool.metadata.fn_schema.__doc__ = tool.metadata.description
+            tool_declarations.append(
+                FunctionDeclaration.from_function(tool.metadata.fn_schema, descriptions)
+            )
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        return {
+            "messages": messages,
+            "tools": ToolDict(function_declarations=tool_declarations)
+            if tool_declarations
+            else None,
+            **kwargs,
+        }
+
+    def get_tool_calls_from_response(
+        self,
+        response: ChatResponse,
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
+        """Predict and call the tool."""
+        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+
+        if len(tool_calls) < 1:
+            if error_on_no_tool_call:
+                raise ValueError(
+                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                )
+            else:
+                return []
+
+        tool_selections = []
+        for tool_call in tool_calls:
+            if not isinstance(tool_call, genai.protos.FunctionCall):
+                raise ValueError("Invalid tool_call object")
+
+            tool_selections.append(
+                ToolSelection(
+                    tool_id=str(uuid.uuid4()),
+                    tool_name=tool_call.name,
+                    tool_kwargs=dict(tool_call.args),
                )
+            )
+
+        return tool_selections
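Together, _prepare_chat_with_tools and get_tool_calls_from_response let Gemini participate in the standard FunctionCallingLLM flow. A minimal sketch of that flow, assuming an illustrative multiply tool and model name (neither is part of this commit):

from llama_index.core.tools import FunctionTool
from llama_index.llms.gemini import Gemini

def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result."""
    return a * b

llm = Gemini(model="models/gemini-1.5-flash")  # model name is an assumption
tool = FunctionTool.from_defaults(fn=multiply)

# chat_with_tools() (inherited from FunctionCallingLLM) goes through the new
# _prepare_chat_with_tools(), which converts each tool's fn_schema into a
# google.generativeai FunctionDeclaration.
response = llm.chat_with_tools(tools=[tool], user_msg="What is 6 times 7?")

# get_tool_calls_from_response() reads additional_kwargs["tool_calls"] and wraps
# each genai FunctionCall in a ToolSelection with a freshly generated tool_id.
for call in llm.get_tool_calls_from_response(response, error_on_no_tool_call=False):
    print(call.tool_name, call.tool_kwargs)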

llama-index-integrations/llms/llama-index-llms-gemini/llama_index/llms/gemini/utils.py

Lines changed: 25 additions & 3 deletions
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Dict, Any
 
 import google.ai.generativelanguage as glm
 import google.generativeai as genai
@@ -68,13 +68,35 @@ def chat_from_gemini_response(
        response.usage_metadata
    )
    role = ROLES_FROM_GEMINI[top_candidate.content.role]
-    return ChatResponse(message=ChatMessage(role=role, content=response.text), raw=raw)
+    try:
+        # When the response contains only a function call, the library
+        # raises an exception.
+        # The easiest way to detect this is to try access the text attribute and
+        # catch the exception.
+        # https://github.com/google-gemini/generative-ai-python/issues/670
+        text = response.text
+    except (ValueError, AttributeError):
+        text = None
+
+    additional_kwargs: Dict[str, Any] = {}
+    for part in response.parts:
+        if fn := part.function_call:
+            if "tool_calls" not in additional_kwargs:
+                additional_kwargs["tool_calls"] = []
+            additional_kwargs["tool_calls"].append(fn)
+
+    return ChatResponse(
+        message=ChatMessage(
+            role=role, content=text, additional_kwargs=additional_kwargs
+        ),
+        raw=raw,
+        additional_kwargs=additional_kwargs,
+    )
 
 
 def chat_message_to_gemini(message: ChatMessage) -> "genai.types.ContentDict":
    """Convert ChatMessages to Gemini-specific history, including ImageDocuments."""
    parts = []
-    content_txt = ""
    for block in message.blocks:
        if isinstance(block, TextBlock):
            parts.append(block.text)
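The change above is what surfaces tool calls to callers: a function-call-only Gemini reply no longer raises when response.text is read, and the raw FunctionCall objects end up in additional_kwargs["tool_calls"]. A small sketch of what that looks like on the resulting ChatResponse, with an illustrative tool and prompt:

from llama_index.core.tools import FunctionTool
from llama_index.llms.gemini import Gemini

def get_weather(city: str) -> str:
    """Return a canned weather report for a city."""
    return f"It is sunny in {city}."

llm = Gemini(model="models/gemini-1.5-flash")  # model name is an assumption
resp = llm.chat_with_tools(
    tools=[FunctionTool.from_defaults(fn=get_weather)],
    user_msg="What's the weather in Paris?",
)

# If Gemini answers with only a function call, message.content is None and the
# genai.protos.FunctionCall objects are available under "tool_calls".
print(resp.message.content)
print(resp.message.additional_kwargs.get("tool_calls", []))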

llama-index-integrations/llms/llama-index-llms-gemini/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-gemini"
 readme = "README.md"
-version = "0.4.6"
+version = "0.4.7"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
