
Commit d30631e

ilayaperumalg authored and sobychacko committed
Google Vertex AI toolcalling token usage
- Accumulate token usage metrics when toolcalling is used
- Fix for both call() and stream() methods
- Add/update tests

Resolves #1992

Signed-off-by: Ilayaperumal Gopinathan <ilayaperumal.gopinathan@broadcom.com>
1 parent 4f67959 commit d30631e
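
In short: internalCall() and internalStream() now take the previous ChatResponse and fold its token counts into the current response via UsageUtils.getCumulativeUsage, so a tool-calling exchange reports usage for every model round trip rather than only the last one. A minimal sketch of the accumulation idea, under the assumption that the helper simply sums prompt and completion tokens (the accumulate method below is hypothetical, for illustration only):

import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;

// Hypothetical sketch of cumulative usage: token counts from the previous
// tool-calling round are added to the current response's counts. The real
// UsageUtils.getCumulativeUsage may differ in detail.
final class UsageAccumulationSketch {

	static Usage accumulate(Usage current, Usage previous) {
		if (previous == null) {
			// First round trip: nothing to accumulate yet.
			return current;
		}
		return new DefaultUsage(current.getPromptTokens() + previous.getPromptTokens(),
				current.getCompletionTokens() + previous.getCompletionTokens());
	}

}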

File tree

2 files changed: +77 −37 lines


models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 35 additions & 34 deletions
@@ -48,7 +48,6 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
 import reactor.core.scheduler.Schedulers;
 
 import org.springframework.ai.chat.messages.AssistantMessage;
@@ -60,7 +59,9 @@
 import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
 import org.springframework.ai.chat.metadata.ChatResponseMetadata;
 import org.springframework.ai.chat.metadata.DefaultUsage;
-import org.springframework.ai.chat.model.AbstractToolCallSupport;
+import org.springframework.ai.chat.metadata.EmptyUsage;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.metadata.UsageUtils;
 import org.springframework.ai.chat.model.ChatModel;
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.model.Generation;
@@ -71,12 +72,11 @@
 import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
 import org.springframework.ai.chat.prompt.ChatOptions;
 import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.model.ChatModelDescription;
 import org.springframework.ai.content.Media;
+import org.springframework.ai.model.ChatModelDescription;
 import org.springframework.ai.model.ModelOptionsUtils;
 import org.springframework.ai.model.function.FunctionCallback;
 import org.springframework.ai.model.function.FunctionCallbackResolver;
-import org.springframework.ai.model.function.FunctionCallingOptions;
 import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
 import org.springframework.ai.model.tool.LegacyToolCallingManager;
 import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -136,12 +136,13 @@
  * @author Soby Chacko
  * @author Jihoon Kim
  * @author Alexandros Pappas
+ * @author Ilayaperumal Gopinathan
  * @since 0.8.1
  * @see VertexAiGeminiChatOptions
  * @see ToolCallingManager
  * @see ChatModel
  */
-public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel, DisposableBean {
+public class VertexAiGeminiChatModel implements ChatModel, DisposableBean {
 
 	private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
 
@@ -277,8 +278,6 @@ public VertexAiGeminiChatModel(VertexAI vertexAI, VertexAiGeminiChatOptions defa
 			ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry,
 			ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
 
-		super(null, VertexAiGeminiChatOptions.builder().build(), List.of());
-
 		Assert.notNull(vertexAI, "VertexAI must not be null");
 		Assert.notNull(defaultOptions, "VertexAiGeminiChatOptions must not be null");
 		Assert.notNull(defaultOptions.getModel(), "VertexAiGeminiChatOptions.modelName must not be null");
@@ -425,10 +424,10 @@ private static Schema jsonToSchema(String json) {
 	@Override
 	public ChatResponse call(Prompt prompt) {
 		var requestPrompt = this.buildRequestPrompt(prompt);
-		return this.internalCall(requestPrompt);
+		return this.internalCall(requestPrompt, null);
 	}
 
-	private ChatResponse internalCall(Prompt prompt) {
+	private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
 
 		ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
 			.prompt(prompt)
@@ -451,8 +450,12 @@ private ChatResponse internalCall(Prompt prompt) {
 				.flatMap(List::stream)
 				.toList();
 
-			ChatResponse chatResponse = new ChatResponse(generations,
-					toChatResponseMetadata(generateContentResponse));
+			GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata();
+			Usage currentUsage = (usage != null)
+					? new DefaultUsage(usage.getPromptTokenCount(), usage.getCandidatesTokenCount())
+					: new EmptyUsage();
+			Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
+			ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage));
 
 			observationContext.setResponse(chatResponse);
 			return chatResponse;
@@ -469,7 +472,8 @@ private ChatResponse internalCall(Prompt prompt) {
 		}
 		else {
 			// Send the tool execution result back to the model.
-			return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
+			return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
+					response);
 		}
 	}
 
@@ -485,10 +489,6 @@ Prompt buildRequestPrompt(Prompt prompt) {
 			runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
 					VertexAiGeminiChatOptions.class);
 		}
-		else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
-			runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
-					VertexAiGeminiChatOptions.class);
-		}
 		else {
 			runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
 					VertexAiGeminiChatOptions.class);
@@ -535,10 +535,10 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
 	@Override
 	public Flux<ChatResponse> stream(Prompt prompt) {
 		var requestPrompt = this.buildRequestPrompt(prompt);
-		return this.internalStream(requestPrompt);
+		return this.internalStream(requestPrompt, null);
 	}
 
-	public Flux<ChatResponse> internalStream(Prompt prompt) {
+	public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
 		return Flux.deferContextual(contextView -> {
 
 			ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
@@ -559,21 +559,22 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
 				ResponseStream<GenerateContentResponse> responseStream = request.model
 					.generateContentStream(request.contents);
 
-				Flux<ChatResponse> chatResponse1 = Flux.fromStream(responseStream.stream())
-					.switchMap(response2 -> Mono.just(response2).map(response -> {
-
-						List<Generation> generations = response.getCandidatesList()
-							.stream()
-							.map(this::responseCandidateToGeneration)
-							.flatMap(List::stream)
-							.toList();
-
-						return new ChatResponse(generations, toChatResponseMetadata(response));
+				Flux<ChatResponse> chatResponseFlux = Flux.fromStream(responseStream.stream()).switchMap(response -> {
+					List<Generation> generations = response.getCandidatesList()
+						.stream()
+						.map(this::responseCandidateToGeneration)
+						.flatMap(List::stream)
+						.toList();
 
-					}));
+					GenerateContentResponse.UsageMetadata usage = response.getUsageMetadata();
+					Usage currentUsage = (usage != null) ? getDefaultUsage(usage) : new EmptyUsage();
+					Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
+					ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage));
+					return Flux.just(chatResponse);
+				});
 
 				// @formatter:off
-				Flux<ChatResponse> chatResponseFlux = chatResponse1.flatMap(response -> {
+				Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
 					if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
 						// FIXME: bounded elastic needs to be used since tool calling
 						// is currently only synchronous
@@ -586,7 +587,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
 							.build());
 					} else {
 						// Send the tool execution result back to the model.
-						return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()));
+						return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
 					}
 				}).subscribeOn(Schedulers.boundedElastic());
 			}
@@ -599,7 +600,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt) {
 				.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
 			// @formatter:on;
 
-			return new MessageAggregator().aggregate(chatResponseFlux, observationContext::setResponse);
+			return new MessageAggregator().aggregate(flux, observationContext::setResponse);
 
 		}
 		catch (Exception e) {
@@ -653,8 +654,8 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
 		}
 	}
 
-	private ChatResponseMetadata toChatResponseMetadata(GenerateContentResponse response) {
-		return ChatResponseMetadata.builder().usage(getDefaultUsage(response.getUsageMetadata())).build();
+	private ChatResponseMetadata toChatResponseMetadata(Usage usage) {
+		return ChatResponseMetadata.builder().usage(usage).build();
 	}
 
 	private DefaultUsage getDefaultUsage(GenerateContentResponse.UsageMetadata usageMetadata) {
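
Net effect of the model changes above, as a hedged caller-side sketch (not code from this commit): after a blocking tool-calling exchange, a single read of the final response's usage covers all round trips.

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;

// Hedged demo: with this change, the usage on the final response of a
// tool-calling call() covers all intermediate model round trips.
final class CumulativeUsageDemo {

	static void printUsage(ChatModel chatModel, Prompt prompt) {
		ChatResponse response = chatModel.call(prompt);
		Usage usage = response.getMetadata().getUsage();
		System.out.printf("prompt=%d completion=%d total=%d%n", usage.getPromptTokens(),
				usage.getCompletionTokens(), usage.getTotalTokens());
	}

}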

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/tool/VertexAiGeminiChatModelToolCallingIT.java

Lines changed: 42 additions & 3 deletions
@@ -118,11 +118,15 @@ public void functionCallTestInferredOpenApiSchema() {
 				.build()))
 			.build();
 
-		ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
+		ChatResponse chatResponse = this.chatModel.call(new Prompt(messages, promptOptions));
 
-		logger.info("Response: {}", response);
+		assertThat(chatResponse).isNotNull();
+		logger.info("Response: {}", chatResponse);
+		assertThat(chatResponse.getResult().getOutput().getText()).contains("30", "10", "15");
 
-		assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
+		assertThat(chatResponse.getMetadata()).isNotNull();
+		assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
+		assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);
 
 		ChatResponse response2 = this.chatModel
 			.call(new Prompt("What is the payment status for transaction 696?", promptOptions));
@@ -166,6 +170,41 @@ public void functionCallTestInferredOpenApiSchemaStream() {
 
 	}
 
+	@Test
+	public void functionCallUsageTestInferredOpenApiSchemaStream() {
+
+		UserMessage userMessage = new UserMessage(
+				"What's the weather like in San Francisco, Paris and in Tokyo? Return the temperature in Celsius.");
+
+		List<Message> messages = new ArrayList<>(List.of(userMessage));
+
+		var promptOptions = VertexAiGeminiChatOptions.builder()
+			.model(VertexAiGeminiChatModel.ChatModel.GEMINI_2_0_FLASH)
+			.toolCallbacks(List.of(
+					FunctionToolCallback.builder("get_current_weather", new MockWeatherService())
+						.description("Get the current weather in a given location.")
+						.inputType(MockWeatherService.Request.class)
+						.build(),
+					FunctionToolCallback.builder("get_payment_status", new PaymentStatus())
+						.description(
+								"Retrieves the payment status for transaction. For example what is the payment status for transaction 700?")
+						.inputType(PaymentInfoRequest.class)
+						.build()))
+			.build();
+
+		Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));
+
+		ChatResponse chatResponse = response.blockLast();
+
+		logger.info("Response: {}", chatResponse);
+
+		assertThat(chatResponse).isNotNull();
+		assertThat(chatResponse.getMetadata()).isNotNull();
+		assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
+		assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isGreaterThan(150).isLessThan(310);
+
+	}
+
 	public record PaymentInfoRequest(String id) {
 
 	}
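
For streaming, the new test reads the cumulative usage off the last element of the Flux; the same pattern in isolation (a hedged sketch; blockLast() is acceptable in tests but blocks the calling thread):

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;

// Hedged sketch: the final streamed ChatResponse carries the cumulative
// usage across all tool-calling rounds.
final class StreamingUsageSketch {

	static Usage finalUsage(ChatModel chatModel, Prompt prompt) {
		Flux<ChatResponse> responses = chatModel.stream(prompt);
		ChatResponse last = responses.blockLast();
		return (last != null) ? last.getMetadata().getUsage() : null;
	}

}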
