48
48
import org .slf4j .Logger ;
49
49
import org .slf4j .LoggerFactory ;
50
50
import reactor .core .publisher .Flux ;
51
- import reactor .core .publisher .Mono ;
52
51
import reactor .core .scheduler .Schedulers ;
53
52
54
53
import org .springframework .ai .chat .messages .AssistantMessage ;
60
59
import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
61
60
import org .springframework .ai .chat .metadata .ChatResponseMetadata ;
62
61
import org .springframework .ai .chat .metadata .DefaultUsage ;
63
- import org .springframework .ai .chat .model .AbstractToolCallSupport ;
62
+ import org .springframework .ai .chat .metadata .EmptyUsage ;
63
+ import org .springframework .ai .chat .metadata .Usage ;
64
+ import org .springframework .ai .chat .metadata .UsageUtils ;
64
65
import org .springframework .ai .chat .model .ChatModel ;
65
66
import org .springframework .ai .chat .model .ChatResponse ;
66
67
import org .springframework .ai .chat .model .Generation ;
71
72
import org .springframework .ai .chat .observation .DefaultChatModelObservationConvention ;
72
73
import org .springframework .ai .chat .prompt .ChatOptions ;
73
74
import org .springframework .ai .chat .prompt .Prompt ;
74
- import org .springframework .ai .model .ChatModelDescription ;
75
75
import org .springframework .ai .content .Media ;
76
+ import org .springframework .ai .model .ChatModelDescription ;
76
77
import org .springframework .ai .model .ModelOptionsUtils ;
77
78
import org .springframework .ai .model .function .FunctionCallback ;
78
79
import org .springframework .ai .model .function .FunctionCallbackResolver ;
79
- import org .springframework .ai .model .function .FunctionCallingOptions ;
80
80
import org .springframework .ai .model .tool .DefaultToolExecutionEligibilityPredicate ;
81
81
import org .springframework .ai .model .tool .LegacyToolCallingManager ;
82
82
import org .springframework .ai .model .tool .ToolCallingChatOptions ;
136
136
* @author Soby Chacko
137
137
* @author Jihoon Kim
138
138
* @author Alexandros Pappas
139
+ * @author Ilayaperumal Gopinathan
139
140
* @since 0.8.1
140
141
* @see VertexAiGeminiChatOptions
141
142
* @see ToolCallingManager
142
143
* @see ChatModel
143
144
*/
144
- public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel , DisposableBean {
145
+ public class VertexAiGeminiChatModel implements ChatModel , DisposableBean {
145
146
146
147
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention ();
147
148
@@ -277,8 +278,6 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa
277
278
ToolCallingManager toolCallingManager , RetryTemplate retryTemplate , ObservationRegistry observationRegistry ,
278
279
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate ) {
279
280
280
- super (null , VertexAiGeminiChatOptions .builder ().build (), List .of ());
281
-
282
281
Assert .notNull (vertexAI , "VertexAI must not be null" );
283
282
Assert .notNull (defaultOptions , "VertexAiGeminiChatOptions must not be null" );
284
283
Assert .notNull (defaultOptions .getModel (), "VertexAiGeminiChatOptions.modelName must not be null" );
@@ -425,10 +424,10 @@ private static Schema jsonToSchema(String json) {
425
424
/**
 * Synchronous chat entry point: normalizes the incoming prompt with
 * {@code buildRequestPrompt} and hands the result to {@code internalCall}.
 * @param prompt the prompt supplied by the caller
 * @return the {@link ChatResponse} returned by {@code internalCall}
 */
@ Override
426
425
public ChatResponse call (Prompt prompt ) {
427
426
var requestPrompt = this .buildRequestPrompt (prompt );
428
- return this .internalCall (requestPrompt );
427
// The added revision passes null as internalCall's previousChatResponse argument;
// presumably this means "no earlier response to accumulate usage from" - confirm
// against UsageUtils.getCumulativeUsage.
+ return this .internalCall (requestPrompt , null );
429
428
}
430
429
431
- private ChatResponse internalCall (Prompt prompt ) {
430
+ private ChatResponse internalCall (Prompt prompt , ChatResponse previousChatResponse ) {
432
431
433
432
ChatModelObservationContext observationContext = ChatModelObservationContext .builder ()
434
433
.prompt (prompt )
@@ -451,8 +450,12 @@ private ChatResponse internalCall(Prompt prompt) {
451
450
.flatMap (List ::stream )
452
451
.toList ();
453
452
454
- ChatResponse chatResponse = new ChatResponse (generations ,
455
- toChatResponseMetadata (generateContentResponse ));
453
+ GenerateContentResponse .UsageMetadata usage = generateContentResponse .getUsageMetadata ();
454
+ Usage currentUsage = (usage != null )
455
+ ? new DefaultUsage (usage .getPromptTokenCount (), usage .getCandidatesTokenCount ())
456
+ : new EmptyUsage ();
457
+ Usage cumulativeUsage = UsageUtils .getCumulativeUsage (currentUsage , previousChatResponse );
458
+ ChatResponse chatResponse = new ChatResponse (generations , toChatResponseMetadata (cumulativeUsage ));
456
459
457
460
observationContext .setResponse (chatResponse );
458
461
return chatResponse ;
@@ -469,7 +472,8 @@ private ChatResponse internalCall(Prompt prompt) {
469
472
}
470
473
else {
471
474
// Send the tool execution result back to the model.
472
- return this .internalCall (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()));
475
+ return this .internalCall (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()),
476
+ response );
473
477
}
474
478
}
475
479
@@ -485,10 +489,6 @@ Prompt buildRequestPrompt(Prompt prompt) {
485
489
runtimeOptions = ModelOptionsUtils .copyToTarget (toolCallingChatOptions , ToolCallingChatOptions .class ,
486
490
VertexAiGeminiChatOptions .class );
487
491
}
488
- else if (prompt .getOptions () instanceof FunctionCallingOptions functionCallingOptions ) {
489
- runtimeOptions = ModelOptionsUtils .copyToTarget (functionCallingOptions , FunctionCallingOptions .class ,
490
- VertexAiGeminiChatOptions .class );
491
- }
492
492
else {
493
493
runtimeOptions = ModelOptionsUtils .copyToTarget (prompt .getOptions (), ChatOptions .class ,
494
494
VertexAiGeminiChatOptions .class );
@@ -535,10 +535,10 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
535
535
/**
 * Streaming chat entry point: normalizes the incoming prompt with
 * {@code buildRequestPrompt} and hands the result to {@code internalStream}.
 * @param prompt the prompt supplied by the caller
 * @return the {@link Flux} of {@link ChatResponse} produced by {@code internalStream}
 */
@ Override
536
536
public Flux <ChatResponse > stream (Prompt prompt ) {
537
537
var requestPrompt = this .buildRequestPrompt (prompt );
538
- return this .internalStream (requestPrompt );
538
// The added revision passes null as internalStream's previousChatResponse argument;
// presumably this means "no earlier response to accumulate usage from" - confirm
// against UsageUtils.getCumulativeUsage.
+ return this .internalStream (requestPrompt , null );
539
539
}
540
540
541
- public Flux <ChatResponse > internalStream (Prompt prompt ) {
541
+ public Flux <ChatResponse > internalStream (Prompt prompt , ChatResponse previousChatResponse ) {
542
542
return Flux .deferContextual (contextView -> {
543
543
544
544
ChatModelObservationContext observationContext = ChatModelObservationContext .builder ()
@@ -559,21 +559,22 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
559
559
ResponseStream <GenerateContentResponse > responseStream = request .model
560
560
.generateContentStream (request .contents );
561
561
562
- Flux <ChatResponse > chatResponse1 = Flux .fromStream (responseStream .stream ())
563
- .switchMap (response2 -> Mono .just (response2 ).map (response -> {
564
-
565
- List <Generation > generations = response .getCandidatesList ()
566
- .stream ()
567
- .map (this ::responseCandidateToGeneration )
568
- .flatMap (List ::stream )
569
- .toList ();
570
-
571
- return new ChatResponse (generations , toChatResponseMetadata (response ));
562
+ Flux <ChatResponse > chatResponseFlux = Flux .fromStream (responseStream .stream ()).switchMap (response -> {
563
+ List <Generation > generations = response .getCandidatesList ()
564
+ .stream ()
565
+ .map (this ::responseCandidateToGeneration )
566
+ .flatMap (List ::stream )
567
+ .toList ();
572
568
573
- }));
569
+ GenerateContentResponse .UsageMetadata usage = response .getUsageMetadata ();
570
+ Usage currentUsage = (usage != null ) ? getDefaultUsage (usage ) : new EmptyUsage ();
571
+ Usage cumulativeUsage = UsageUtils .getCumulativeUsage (currentUsage , previousChatResponse );
572
+ ChatResponse chatResponse = new ChatResponse (generations , toChatResponseMetadata (cumulativeUsage ));
573
+ return Flux .just (chatResponse );
574
+ });
574
575
575
576
// @formatter:off
576
- Flux <ChatResponse > chatResponseFlux = chatResponse1 .flatMap (response -> {
577
+ Flux <ChatResponse > flux = chatResponseFlux .flatMap (response -> {
577
578
if (toolExecutionEligibilityPredicate .isToolExecutionRequired (prompt .getOptions (), response )) {
578
579
// FIXME: bounded elastic needs to be used since tool calling
579
580
// is currently only synchronous
@@ -586,7 +587,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
586
587
.build ());
587
588
} else {
588
589
// Send the tool execution result back to the model.
589
- return this .internalStream (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()));
590
+ return this .internalStream (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()), response );
590
591
}
591
592
}).subscribeOn (Schedulers .boundedElastic ());
592
593
}
@@ -599,7 +600,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
599
600
.contextWrite (ctx -> ctx .put (ObservationThreadLocalAccessor .KEY , observation ));
600
601
// @formatter:on;
601
602
602
- return new MessageAggregator ().aggregate (chatResponseFlux , observationContext ::setResponse );
603
+ return new MessageAggregator ().aggregate (flux , observationContext ::setResponse );
603
604
604
605
}
605
606
catch (Exception e ) {
@@ -653,8 +654,8 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
653
654
}
654
655
}
655
656
656
/**
 * Builds a {@link ChatResponseMetadata} whose only populated field is the usage.
 * The removed revision derived the usage itself from the response's usage metadata
 * via {@code getDefaultUsage}; the added revision takes an already-computed
 * {@link Usage} and stores it as given.
 * @param usage the usage value to place on the metadata (added revision)
 * @return metadata built from {@code ChatResponseMetadata.builder()} with that usage
 */
- private ChatResponseMetadata toChatResponseMetadata (GenerateContentResponse response ) {
657
- return ChatResponseMetadata .builder ().usage (getDefaultUsage ( response . getUsageMetadata ()) ).build ();
657
+ private ChatResponseMetadata toChatResponseMetadata (Usage usage ) {
658
+ return ChatResponseMetadata .builder ().usage (usage ).build ();
658
659
}
659
660
660
661
private DefaultUsage getDefaultUsage (GenerateContentResponse .UsageMetadata usageMetadata ) {
0 commit comments