@@ -2,7 +2,8 @@

import os
import warnings
-from typing import Any, Dict, Optional, Sequence, cast
+import uuid
+from typing import TYPE_CHECKING, Union, List, Any, Dict, Optional, Sequence, cast

import google.generativeai as genai
from google.generativeai.types import generation_types
@@ -13,15 +14,17 @@
    ChatResponseGen,
    CompletionResponse,
    CompletionResponseGen,
+    CompletionResponseAsyncGen,
    LLMMetadata,
+    MessageRole,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
from llama_index.core.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE
+from llama_index.core.llms.llm import ToolSelection
from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
-from llama_index.core.llms.custom import CustomLLM
+from llama_index.core.llms.function_calling import FunctionCallingLLM
from llama_index.core.utilities.gemini_utils import (
-    ROLES_FROM_GEMINI,
    merge_neighboring_same_role_messages,
)

@@ -49,8 +52,11 @@
    "gemini-1.0-pro",
)

+if TYPE_CHECKING:
+    from llama_index.core.tools.types import BaseTool

-class Gemini(CustomLLM):
+
+class Gemini(FunctionCallingLLM):
    """
    Gemini LLM.

@@ -181,6 +187,8 @@ def metadata(self) -> LLMMetadata:
            num_output=self.max_tokens,
            model_name=self.model,
            is_chat_model=True,
+            # All gemini models support function calling
+            is_function_calling_model=True,
        )

    @llm_completion_callback()
@@ -208,10 +216,30 @@ def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        request_options = self._request_options or kwargs.pop("request_options", None)
-        it = self._model.generate_content(
-            prompt, stream=True, request_options=request_options, **kwargs
-        )
-        yield from map(completion_from_gemini_response, it)
+
+        def gen():
+            it = self._model.generate_content(
+                prompt, stream=True, request_options=request_options, **kwargs
+            )
+            for r in it:
+                yield completion_from_gemini_response(r)
+
+        return gen()
+
+    @llm_completion_callback()
+    def astream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseAsyncGen:
+        request_options = self._request_options or kwargs.pop("request_options", None)
+
+        async def gen():
+            it = await self._model.generate_content_async(
+                prompt, stream=True, request_options=request_options, **kwargs
+            )
+            async for r in it:
+                yield completion_from_gemini_response(r)
+
+        return gen()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
@@ -256,19 +284,11 @@ def gen() -> ChatResponseGen:
            for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
-                role = ROLES_FROM_GEMINI[top_candidate.content.role]
-                raw = {
-                    **(type(top_candidate).to_dict(top_candidate)),  # type: ignore
-                    **(
-                        type(response.prompt_feedback).to_dict(response.prompt_feedback)  # type: ignore
-                    ),
-                }
                content += content_delta
-                yield ChatResponse(
-                    message=ChatMessage(role=role, content=content),
-                    delta=content_delta,
-                    raw=raw,
-                )
+                llama_resp = chat_from_gemini_response(r)
+                llama_resp.delta = content_delta
+                llama_resp.message.content = content
+                yield llama_resp

        return gen()

@@ -278,7 +298,7 @@ async def astream_chat(
    ) -> ChatResponseAsyncGen:
        request_options = self._request_options or kwargs.pop("request_options", None)
        merged_messages = merge_neighboring_same_role_messages(messages)
-        *history, next_msg = map(chat_message_to_gemini, messages)
+        *history, next_msg = map(chat_message_to_gemini, merged_messages)
        chat = self._model.start_chat(history=history)
        response = await chat.send_message_async(
            next_msg, stream=True, request_options=request_options, **kwargs
@@ -289,18 +309,86 @@ async def gen() -> ChatResponseAsyncGen:
            async for r in response:
                top_candidate = r.candidates[0]
                content_delta = top_candidate.content.parts[0].text
-                role = ROLES_FROM_GEMINI[top_candidate.content.role]
-                raw = {
-                    **(type(top_candidate).to_dict(top_candidate)),  # type: ignore
-                    **(
-                        type(response.prompt_feedback).to_dict(response.prompt_feedback)  # type: ignore
-                    ),
-                }
                content += content_delta
-                yield ChatResponse(
-                    message=ChatMessage(role=role, content=content),
-                    delta=content_delta,
-                    raw=raw,
-                )
+                llama_resp = chat_from_gemini_response(r)
+                llama_resp.delta = content_delta
+                llama_resp.message.content = content
+                yield llama_resp

        return gen()
+
+    def _prepare_chat_with_tools(
+        self,
+        tools: Sequence["BaseTool"],
+        user_msg: Optional[Union[str, ChatMessage]] = None,
+        chat_history: Optional[List[ChatMessage]] = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        tool_choice: Union[str, dict] = "auto",
+        strict: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Predict and call the tool."""
+        from google.generativeai.types import FunctionDeclaration, ToolDict
+
+        tool_declarations = []
+        for tool in tools:
+            descriptions = {}
+            for param_name, param_schema in tool.metadata.get_parameters_dict()[
+                "properties"
+            ].items():
+                param_description = param_schema.get("description", None)
+                if param_description:
+                    descriptions[param_name] = param_description
+
+            tool.metadata.fn_schema.__doc__ = tool.metadata.description
+            tool_declarations.append(
+                FunctionDeclaration.from_function(tool.metadata.fn_schema, descriptions)
+            )
+
+        if isinstance(user_msg, str):
+            user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
+
+        messages = chat_history or []
+        if user_msg:
+            messages.append(user_msg)
+
+        return {
+            "messages": messages,
+            "tools": ToolDict(function_declarations=tool_declarations)
+            if tool_declarations
+            else None,
+            **kwargs,
+        }
+
+    def get_tool_calls_from_response(
+        self,
+        response: ChatResponse,
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
+        """Predict and call the tool."""
+        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+
+        if len(tool_calls) < 1:
+            if error_on_no_tool_call:
+                raise ValueError(
+                    f"Expected at least one tool call, but got {len(tool_calls)} tool calls."
+                )
+            else:
+                return []
+
+        tool_selections = []
+        for tool_call in tool_calls:
+            if not isinstance(tool_call, genai.protos.FunctionCall):
+                raise ValueError("Invalid tool_call object")
+
+            tool_selections.append(
+                ToolSelection(
+                    tool_id=str(uuid.uuid4()),
+                    tool_name=tool_call.name,
+                    tool_kwargs=dict(tool_call.args),
+                )
+            )
+
+        return tool_selections
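
Not part of the commit: a minimal usage sketch of the function-calling path added above. It assumes llama-index-core's FunctionTool, a GOOGLE_API_KEY in the environment, and an illustrative model name; any Gemini model should work, since metadata() now reports is_function_calling_model=True.

# Minimal usage sketch (assumption: not from this commit)
from llama_index.core.tools import FunctionTool
from llama_index.llms.gemini import Gemini


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


tool = FunctionTool.from_defaults(fn=multiply)
llm = Gemini(model="models/gemini-1.5-flash")  # illustrative model name

# chat_with_tools() is inherited from FunctionCallingLLM; it routes through the
# new _prepare_chat_with_tools() to build the genai ToolDict before chatting.
response = llm.chat_with_tools([tool], user_msg="What is 3 times 4?")

# get_tool_calls_from_response() converts genai FunctionCall protos into
# ToolSelection objects with generated UUID tool ids.
for call in llm.get_tool_calls_from_response(response, error_on_no_tool_call=False):
    print(call.tool_name, call.tool_kwargs)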