Skip to content

Commit 00ede3f

Browse files
tzolovilayaperumalg
authored andcommitted
Fix Bedrock Converse streaming/call and token handling
- Modify stream method to support recursive tool call handling - Update token tracking and metadata merging for streamed responses - Improve token usage calculation for tool use events - Update test cases to handle new response processing - Modify call method to support recursive tool call handling - Add support for cumulative token tracking across tool call iterations - Introduce internal call method to track and aggregate token usage - Merge previous chat response tokens with current response tokens Resolves #1743
1 parent 13d2074 commit 00ede3f

File tree

4 files changed

+145
-24
lines changed

4 files changed

+145
-24
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
169169
*/
170170
@Override
171171
public ChatResponse call(Prompt prompt) {
172+
return this.internalCall(prompt, null);
173+
}
174+
175+
private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse) {
172176

173177
ConverseRequest converseRequest = this.createRequest(prompt);
174178

@@ -185,7 +189,7 @@ public ChatResponse call(Prompt prompt) {
185189

186190
ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(converseRequest);
187191

188-
var response = this.toChatResponse(converseResponse);
192+
var response = this.toChatResponse(converseResponse, perviousChatResponse);
189193

190194
observationContext.setResponse(response);
191195

@@ -195,7 +199,7 @@ public ChatResponse call(Prompt prompt) {
195199
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
196200
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {
197201
var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
198-
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
202+
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
199203
}
200204

201205
return chatResponse;
@@ -402,7 +406,7 @@ else if (mediaData instanceof URL url) {
402406
* @param response The Bedrock Converse response.
403407
* @return The ChatResponse entity.
404408
*/
405-
private ChatResponse toChatResponse(ConverseResponse response) {
409+
private ChatResponse toChatResponse(ConverseResponse response, ChatResponse perviousChatResponse) {
406410

407411
Assert.notNull(response, "'response' must not be null.");
408412

@@ -448,8 +452,19 @@ private ChatResponse toChatResponse(ConverseResponse response) {
448452
allGenerations.add(toolCallGeneration);
449453
}
450454

451-
DefaultUsage usage = new DefaultUsage(response.usage().inputTokens().longValue(),
452-
response.usage().outputTokens().longValue(), response.usage().totalTokens().longValue());
455+
Long promptTokens = response.usage().inputTokens().longValue();
456+
Long generationTokens = response.usage().outputTokens().longValue();
457+
Long totalTokens = response.usage().totalTokens().longValue();
458+
459+
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
460+
&& perviousChatResponse.getMetadata().getUsage() != null) {
461+
462+
promptTokens += perviousChatResponse.getMetadata().getUsage().getPromptTokens();
463+
generationTokens += perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
464+
totalTokens += perviousChatResponse.getMetadata().getUsage().getTotalTokens();
465+
}
466+
467+
DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens);
453468

454469
Document modelResponseFields = response.additionalModelResponseFields();
455470

@@ -473,14 +488,16 @@ private ChatResponse toChatResponse(ConverseResponse response) {
473488
*/
474489
@Override
475490
public Flux<ChatResponse> stream(Prompt prompt) {
491+
return this.internalStream(prompt, null);
492+
}
493+
494+
private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousChatResponse) {
476495
Assert.notNull(prompt, "'prompt' must not be null");
477496

478497
return Flux.deferContextual(contextView -> {
479498

480499
ConverseRequest converseRequest = this.createRequest(prompt);
481500

482-
// System.out.println(">>>>> CONVERSE REQUEST: " + converseRequest);
483-
484501
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
485502
.prompt(prompt)
486503
.provider(AiProvider.BEDROCK_CONVERSE.value())
@@ -504,13 +521,13 @@ public Flux<ChatResponse> stream(Prompt prompt) {
504521
Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest);
505522

506523
// @formatter:off
507-
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response);
524+
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse);
508525

509526
Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
510527
if (!this.isProxyToolCalls(prompt, this.defaultOptions) && chatResponse != null
511528
&& this.isToolCall(chatResponse, Set.of("tool_use"))) {
512529
var toolCallConversation = this.handleToolCalls(prompt, chatResponse);
513-
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
530+
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), chatResponse);
514531
}
515532
return Mono.just(chatResponse);
516533
})

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ public static boolean isToolUseFinish(ConverseStreamOutput event) {
9090
return true;
9191
}
9292

93-
public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> responses) {
93+
public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> responses,
94+
ChatResponse perviousChatResponse) {
9495

9596
AtomicBoolean isInsideTool = new AtomicBoolean(false);
9697

@@ -120,20 +121,30 @@ public static Flux<ChatResponse> toChatResponse(Flux<ConverseStreamOutput> respo
120121

121122
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
122123

124+
Long promptTokens = 0L;
125+
Long generationTokens = 0L;
126+
Long totalTokens = 0L;
127+
123128
for (ToolUseAggregationEvent.ToolUseEntry toolUseEntry : toolUseAggregationEvent.toolUseEntries()) {
124129
var functionCallId = toolUseEntry.id();
125130
var functionName = toolUseEntry.name();
126131
var functionArguments = toolUseEntry.input();
127132
toolCalls.add(
128133
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
134+
135+
if (toolUseEntry.usage() != null) {
136+
promptTokens += toolUseEntry.usage().getPromptTokens();
137+
generationTokens += toolUseEntry.usage().getGenerationTokens();
138+
totalTokens += toolUseEntry.usage().getTotalTokens();
139+
}
129140
}
130141

131142
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
132143
Generation toolCallGeneration = new Generation(assistantMessage,
133144
ChatGenerationMetadata.from("tool_use", null));
134145

135146
var chatResponseMetaData = ChatResponseMetadata.builder()
136-
.withUsage(toolUseAggregationEvent.usage)
147+
.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens))
137148
.build();
138149

139150
return new Aggregation(
@@ -181,22 +192,22 @@ else if (nextEvent instanceof ContentBlockStopEvent contentBlockStopEvent) {
181192
return new Aggregation();
182193
}
183194
else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
184-
// return new Aggregation();
195+
185196
var newMeta = MetadataAggregation.builder()
186197
.copy(lastAggregation.metadataAggregation())
187198
.withTokenUsage(metadataEvent.usage())
188199
.withMetrics(metadataEvent.metrics())
189200
.withTrace(metadataEvent.trace())
190201
.build();
191202

192-
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
193-
metadataEvent.usage().outputTokens().longValue(),
194-
metadataEvent.usage().totalTokens().longValue());
195-
196203
// TODO
197204
Document modelResponseFields = lastAggregation.metadataAggregation().additionalModelResponseFields();
198205
ConverseStreamMetrics metrics = metadataEvent.metrics();
199206

207+
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
208+
metadataEvent.usage().outputTokens().longValue(),
209+
metadataEvent.usage().totalTokens().longValue());
210+
200211
var chatResponseMetaData = ChatResponseMetadata.builder().withUsage(usage).build();
201212

202213
return new Aggregation(newMeta, new ChatResponse(List.of(), chatResponseMetaData));
@@ -206,8 +217,42 @@ else if (nextEvent instanceof ConverseStreamMetadataEvent metadataEvent) {
206217
}
207218
})
208219
// .skip(1)
209-
.map(aggregation -> aggregation.chatResponse())
210-
.filter(chatResponse -> chatResponse != ConverseApiUtils.EMPTY_CHAT_RESPONSE);
220+
.filter(aggregation -> aggregation.chatResponse() != ConverseApiUtils.EMPTY_CHAT_RESPONSE)
221+
.map(aggregation -> {
222+
223+
var chatResponse = aggregation.chatResponse();
224+
225+
// Merge the previous chat response metadata with the current one.
226+
if (perviousChatResponse != null && perviousChatResponse.getMetadata() != null
227+
&& perviousChatResponse.getMetadata().getUsage() != null) {
228+
229+
var metadataBuilder = ChatResponseMetadata.builder();
230+
231+
Long promptTokens = perviousChatResponse.getMetadata().getUsage().getPromptTokens();
232+
Long generationTokens = perviousChatResponse.getMetadata().getUsage().getGenerationTokens();
233+
Long totalTokens = perviousChatResponse.getMetadata().getUsage().getTotalTokens();
234+
235+
if (chatResponse.getMetadata() != null) {
236+
metadataBuilder.withId(chatResponse.getMetadata().getId());
237+
metadataBuilder.withModel(chatResponse.getMetadata().getModel());
238+
metadataBuilder.withRateLimit(chatResponse.getMetadata().getRateLimit());
239+
metadataBuilder.withPromptMetadata(chatResponse.getMetadata().getPromptMetadata());
240+
241+
if (chatResponse.getMetadata().getUsage() != null) {
242+
promptTokens = promptTokens + chatResponse.getMetadata().getUsage().getPromptTokens();
243+
generationTokens = generationTokens
244+
+ chatResponse.getMetadata().getUsage().getGenerationTokens();
245+
totalTokens = totalTokens + chatResponse.getMetadata().getUsage().getTotalTokens();
246+
}
247+
}
248+
249+
metadataBuilder.withUsage(new DefaultUsage(promptTokens, generationTokens, totalTokens));
250+
251+
return new ChatResponse(chatResponse.getResults(), metadataBuilder.build());
252+
}
253+
254+
return aggregation.chatResponse();
255+
});
211256
}
212257

213258
public static ConverseStreamOutput mergeToolUseEvents(ConverseStreamOutput previousEvent,
@@ -245,7 +290,7 @@ else if (event.sdkEventType() == EventType.METADATA) {
245290
DefaultUsage usage = new DefaultUsage(metadataEvent.usage().inputTokens().longValue(),
246291
metadataEvent.usage().outputTokens().longValue(), metadataEvent.usage().totalTokens().longValue());
247292
toolUseEventAggregator.withUsage(usage);
248-
// TODO
293+
249294
if (!toolUseEventAggregator.isEmpty()) {
250295
toolUseEventAggregator.squashIntoContentBlock();
251296
return toolUseEventAggregator;
@@ -400,7 +445,7 @@ ToolUseAggregationEvent appendPartialJson(String partialJson) {
400445
}
401446

402447
void squashIntoContentBlock() {
403-
this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson));
448+
this.toolUseEntries.add(new ToolUseEntry(this.index, this.id, this.name, this.partialJson, this.usage));
404449
this.index = null;
405450
this.id = null;
406451
this.name = null;
@@ -424,7 +469,7 @@ public void accept(Visitor visitor) {
424469
throw new UnsupportedOperationException();
425470
}
426471

427-
public record ToolUseEntry(Integer index, String id, String name, String input) {
472+
public record ToolUseEntry(Integer index, String id, String name, String input, DefaultUsage usage) {
428473
}
429474

430475
}

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.springframework.util.MimeTypeUtils;
5050

5151
import static org.assertj.core.api.Assertions.assertThat;
52+
import static org.mockito.ArgumentMatchers.matches;
5253

5354
@SpringBootTest(classes = BedrockConverseTestConfiguration.class)
5455
@EnabledIfEnvironmentVariable(named = "AWS_ACCESS_KEY_ID", matches = ".*")
@@ -227,6 +228,41 @@ void functionCallTest() {
227228
assertThat(response).contains("30", "10", "15");
228229
}
229230

231+
@Test
232+
void functionCallWithUsageMetadataTest() {
233+
234+
// @formatter:off
235+
ChatResponse response = ChatClient.create(this.chatModel)
236+
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
237+
.functions(FunctionCallback.builder()
238+
.description("Get the weather in location")
239+
.function("getCurrentWeather", new MockWeatherService())
240+
.inputType(MockWeatherService.Request.class)
241+
.build())
242+
.call()
243+
.chatResponse();
244+
// @formatter:on
245+
246+
var metadata = response.getMetadata();
247+
248+
assertThat(metadata.getUsage()).isNotNull();
249+
250+
logger.info(metadata.getUsage().toString());
251+
252+
assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(500);
253+
assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);
254+
255+
assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
256+
assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);
257+
258+
assertThat(metadata.getUsage().getTotalTokens())
259+
.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());
260+
261+
logger.info("Response: {}", response);
262+
263+
assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
264+
}
265+
230266
@Test
231267
void functionCallWithAdvisorTest() {
232268

@@ -274,18 +310,39 @@ void defaultFunctionCallTest() {
274310
void streamFunctionCallTest() {
275311

276312
// @formatter:off
277-
Flux<String> response = ChatClient.create(this.chatModel).prompt()
313+
Flux<ChatResponse> response = ChatClient.create(this.chatModel).prompt()
278314
.user("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
279315
.functions(FunctionCallback.builder()
280316
.description("Get the weather in location")
281317
.function("getCurrentWeather", new MockWeatherService())
282318
.inputType(MockWeatherService.Request.class)
283319
.build())
284320
.stream()
285-
.content();
321+
.chatResponse();
286322
// @formatter:on
287323

288-
String content = response.collectList().block().stream().collect(Collectors.joining());
324+
List<ChatResponse> chatResponses = response.collectList().block();
325+
326+
// chatResponses.forEach(cr -> logger.info("Response: {}", cr));
327+
var lastChatResponse = chatResponses.get(chatResponses.size() - 1);
328+
var metadata = lastChatResponse.getMetadata();
329+
assertThat(metadata.getUsage()).isNotNull();
330+
331+
logger.info(metadata.getUsage().toString());
332+
333+
assertThat(metadata.getUsage().getPromptTokens()).isGreaterThan(1500);
334+
assertThat(metadata.getUsage().getPromptTokens()).isLessThan(3500);
335+
336+
assertThat(metadata.getUsage().getGenerationTokens()).isGreaterThan(0);
337+
assertThat(metadata.getUsage().getGenerationTokens()).isLessThan(1500);
338+
339+
assertThat(metadata.getUsage().getTotalTokens())
340+
.isEqualTo(metadata.getUsage().getPromptTokens() + metadata.getUsage().getGenerationTokens());
341+
342+
String content = chatResponses.stream()
343+
.filter(cr -> cr.getResult() != null)
344+
.map(cr -> cr.getResult().getOutput().getContent())
345+
.collect(Collectors.joining());
289346
logger.info("Response: {}", content);
290347

291348
assertThat(content).contains("30", "10", "15");

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/experiements/BedrockConverseChatModelMain2.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ public static void main(String[] args) {
6969
Flux<ConverseStreamOutput> responses = chatModel.converseStream(streamRequest);
7070
List<ConverseStreamOutput> responseList = responses.collectList().block();
7171
System.out.println(responseList);
72+
System.out.println("Response count: " + responseList.size());
73+
responseList.forEach(System.out::println);
7274
}
7375

7476
}

0 commit comments

Comments
 (0)