
Commit 7bb553e

andresssantos authored and ilayaperumalg committed

Add support for stream usage in Azure OpenAI

This PR introduces support for streamUsage in AzureOpenAiChatOptions:

- Sets com.azure.ai.openai.models.ChatCompletionStreamOptions#includeUsage via AzureOpenAiChatOptions.
- Updates the unit test AzureOpenAiChatOptionsTests to reflect the changes.
- Updates the documentation in azure-openai-chat.adoc.

Signed-off-by: Andres da Silva Santos <40636137+andresssantos@users.noreply.github.com>

1 parent 5b7849d commit 7bb553e
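For orientation, here is a minimal sketch of how a caller might enable the new flag through the `streamUsage(...)` builder method introduced by this commit. The deployment name and the pre-built `chatModel` are illustrative assumptions, not part of this change:

import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;

// Runtime options carrying the new flag; the deployment name is illustrative.
AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder()
		.deploymentName("gpt-4o")
		.streamUsage(true) // ask Azure to append a final chunk with token usage
		.build();

// chatModel is assumed to be an already-configured AzureOpenAiChatModel.
Flux<ChatResponse> responses = chatModel.stream(new Prompt("Tell me a joke", options));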

File tree

4 files changed: 51 additions & 7 deletions

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 16 additions & 1 deletion
@@ -36,6 +36,7 @@
 import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
 import com.azure.ai.openai.models.ChatCompletionsOptions;
 import com.azure.ai.openai.models.ChatCompletionsResponseFormat;
+import com.azure.ai.openai.models.ChatCompletionStreamOptions;
 import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
 import com.azure.ai.openai.models.ChatCompletionsToolCall;
 import com.azure.ai.openai.models.ChatCompletionsToolDefinition;
@@ -113,6 +114,7 @@
  * @author Ilayaperumal Gopinathan
  * @author Alexandros Pappas
  * @author Berjan Jonker
+ * @author Andres da Silva Santos
  * @see ChatModel
  * @see com.azure.ai.openai.OpenAIClient
  * @since 1.0.0
@@ -498,8 +500,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

 		options = this.merge(options, this.defaultOptions);

+		AzureOpenAiChatOptions updatedRuntimeOptions;
+
 		if (prompt.getOptions() != null) {
-			AzureOpenAiChatOptions updatedRuntimeOptions;
 			if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
 				updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions,
 						ToolCallingChatOptions.class, AzureOpenAiChatOptions.class);
@@ -523,6 +526,15 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
 			options.setTools(tools2);
 		}

+		Boolean enableStreamUsage = (prompt.getOptions() instanceof AzureOpenAiChatOptions azureOpenAiChatOptions
+				&& azureOpenAiChatOptions.getStreamUsage() != null) ? azureOpenAiChatOptions.getStreamUsage()
+				: this.defaultOptions.getStreamUsage();
+
+		if (Boolean.TRUE.equals(enableStreamUsage) && options.getStreamOptions() == null) {
+			ChatCompletionsOptionsAccessHelper.setStreamOptions(options,
+					new ChatCompletionStreamOptions().setIncludeUsage(true));
+		}
+
 		return options;
 	}

@@ -646,6 +658,8 @@ Prompt buildRequestPrompt(Prompt prompt) {
 			requestOptions.setInternalToolExecutionEnabled(
 					ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
 							this.defaultOptions.getInternalToolExecutionEnabled()));
+			requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(),
+					this.defaultOptions.getStreamUsage()));
 			requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
 					this.defaultOptions.getToolNames()));
 			requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
@@ -655,6 +669,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
 		}
 		else {
 			requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
+			requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage());
 			requestOptions.setToolNames(this.defaultOptions.getToolNames());
 			requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
 			requestOptions.setToolContext(this.defaultOptions.getToolContext());
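The resolution above follows the usual runtime-over-default precedence: a `streamUsage` value on the per-prompt options wins, otherwise the default options apply, and the Azure `ChatCompletionStreamOptions` is only installed when no stream options are already set. On the consuming side, a caller could surface the usage carried by the final chunk roughly as follows; this is a sketch assuming the standard Spring AI `ChatResponseMetadata`/`Usage` accessors and the `chatModel` and `options` from the earlier example:

// Sketch: print the token usage delivered by the final streamed chunk.
chatModel.stream(new Prompt("Summarize Spring AI", options))
		.doOnNext(response -> {
			var usage = response.getMetadata().getUsage();
			// Intermediate chunks report no usage; the final chunk should.
			if (usage != null && usage.getTotalTokens() != null && usage.getTotalTokens() > 0) {
				System.out.printf("prompt=%d completion=%d total=%d%n",
						usage.getPromptTokens(), usage.getCompletionTokens(), usage.getTotalTokens());
			}
		})
		.blockLast(); // blocking here for demo purposes only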

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 25 additions & 2 deletions
@@ -47,6 +47,7 @@
  * @author Soby Chacko
  * @author Ilayaperumal Gopinathan
  * @author Alexandros Pappas
+ * @author Andres da Silva Santos
  */
 @JsonInclude(Include.NON_NULL)
 public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
@@ -199,6 +200,13 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
 	@JsonIgnore
 	private Boolean internalToolExecutionEnabled;

+	/**
+	 * Whether to include token usage information in streaming chat completion responses.
+	 * Only applies to streaming responses.
+	 */
+	@JsonIgnore
+	private Boolean enableStreamUsage;
+
 	@Override
 	@JsonIgnore
 	public List<ToolCallback> getToolCallbacks() {
@@ -259,6 +267,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) {
 				fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
 			.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
 			.responseFormat(fromOptions.getResponseFormat())
+			.streamUsage(fromOptions.getStreamUsage())
 			.seed(fromOptions.getSeed())
 			.logprobs(fromOptions.isLogprobs())
 			.topLogprobs(fromOptions.getTopLogProbs())
@@ -391,6 +400,14 @@ public void setResponseFormat(AzureOpenAiResponseFormat responseFormat) {
 		this.responseFormat = responseFormat;
 	}

+	public Boolean getStreamUsage() {
+		return this.enableStreamUsage;
+	}
+
+	public void setStreamUsage(Boolean enableStreamUsage) {
+		this.enableStreamUsage = enableStreamUsage;
+	}
+
 	@Override
 	@JsonIgnore
 	public Integer getTopK() {
@@ -472,6 +489,7 @@ public boolean equals(Object o) {
 				&& Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs)
 				&& Objects.equals(this.enhancements, that.enhancements)
 				&& Objects.equals(this.streamOptions, that.streamOptions)
+				&& Objects.equals(this.enableStreamUsage, that.enableStreamUsage)
 				&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens)
 				&& Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
 				&& Objects.equals(this.presencePenalty, that.presencePenalty)
@@ -482,8 +500,8 @@ public boolean equals(Object o) {
 	public int hashCode() {
 		return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
 				this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
-				this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens,
-				this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
+				this.topLogProbs, this.enhancements, this.streamOptions, this.enableStreamUsage, this.toolContext,
+				this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
 	}

 	public static class Builder {
@@ -553,6 +571,11 @@ public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) {
 			return this;
 		}

+		public Builder streamUsage(Boolean enableStreamUsage) {
+			this.options.enableStreamUsage = enableStreamUsage;
+			return this;
+		}
+
 		public Builder seed(Long seed) {
 			this.options.seed = seed;
 			return this;
models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java

Lines changed: 9 additions & 4 deletions
@@ -56,6 +56,7 @@ void testBuilderWithAllFields() {
 			.topP(0.9)
 			.user("test-user")
 			.responseFormat(responseFormat)
+			.streamUsage(true)
 			.seed(12345L)
 			.logprobs(true)
 			.topLogprobs(5)
@@ -65,11 +66,11 @@ void testBuilderWithAllFields() {

 		assertThat(options)
 			.extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop",
-					"temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements",
-					"streamOptions")
+					"temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs",
+					"enhancements", "streamOptions")
 			.containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8,
-					List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, 12345L, true, 5, enhancements,
-					streamOptions);
+					List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5,
+					enhancements, streamOptions);
 	}

 	@Test
@@ -94,6 +95,7 @@ void testCopy() {
 			.topP(0.9)
 			.user("test-user")
 			.responseFormat(responseFormat)
+			.streamUsage(true)
 			.seed(12345L)
 			.logprobs(true)
 			.topLogprobs(5)
@@ -128,6 +130,7 @@ void testSetters() {
 		options.setTopP(0.9);
 		options.setUser("test-user");
 		options.setResponseFormat(responseFormat);
+		options.setStreamUsage(true);
 		options.setSeed(12345L);
 		options.setLogprobs(true);
 		options.setTopLogProbs(5);
@@ -148,6 +151,7 @@ void testSetters() {
 		assertThat(options.getTopP()).isEqualTo(0.9);
 		assertThat(options.getUser()).isEqualTo("test-user");
 		assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
+		assertThat(options.getStreamUsage()).isTrue();
 		assertThat(options.getSeed()).isEqualTo(12345L);
 		assertThat(options.isLogprobs()).isTrue();
 		assertThat(options.getTopLogProbs()).isEqualTo(5);
@@ -171,6 +175,7 @@ void testDefaultValues() {
 		assertThat(options.getTopP()).isNull();
 		assertThat(options.getUser()).isNull();
 		assertThat(options.getResponseFormat()).isNull();
+		assertThat(options.getStreamUsage()).isNull();
 		assertThat(options.getSeed()).isNull();
 		assertThat(options.isLogprobs()).isNull();
 		assertThat(options.getTopLogProbs()).isNull();

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ Deployments model name to provide as part of this completions request. | gpt-4o
 | spring.ai.azure.openai.chat.options.topP | An alternative to sampling with temperature called nucleus sampling. This value causes the model to consider the results of tokens with the provided probability mass. | -
 | spring.ai.azure.openai.chat.options.logitBias | A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions response. Token IDs are computed via external tokenizer tools, while bias scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection of a token, respectively. The exact behavior of a given bias score varies by model. | -
 | spring.ai.azure.openai.chat.options.user | An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. | -
+| spring.ai.azure.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a `usage` field, but with a null value. | false
 | spring.ai.azure.openai.chat.options.n | The number of chat completions choices that should be generated for a chat completions response. | -
 | spring.ai.azure.openai.chat.options.stop | A collection of textual sequences that will end completions generation. | -
 | spring.ai.azure.openai.chat.options.presencePenalty | A value that influences the probability of generated tokens appearing based on their existing presence in generated text. Positive values will make tokens less likely to appear when they already exist and increase the model's likelihood to output new topics. | -
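Like the other options in this table, the new property would presumably be set through standard Spring Boot configuration binding; for example, in application.properties:

# Enable the extra usage chunk for streamed Azure OpenAI completions
spring.ai.azure.openai.chat.options.stream-usage=true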
