Skip to content

Commit 42dcb45

Browse files
sobychackomarkpollack
authored andcommitted
Align AzureOpenAiChatOptions with Azure ChatCompletionsOptions
Add missing options from Azure ChatCompletionsOptions to Spring AI AzureOpenAiChatOptions. The following fields have been added: - seed - logprobs - topLogprobs - enhancements This change ensures better alignment between the two option sets, improving compatibility and feature parity. Resolves #889
1 parent 35e6113 commit 42dcb45

File tree

3 files changed

+166
-3
lines changed

3 files changed

+166
-3
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.azure.openai;
1718

1819
import java.util.ArrayList;
@@ -92,6 +93,7 @@
9293
* @author Thomas Vitale
9394
* @author luocongqiu
9495
* @author timostark
96+
* @author Soby Chacko
9597
* @see ChatModel
9698
* @see com.azure.ai.openai.OpenAIClient
9799
*/
@@ -456,6 +458,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
456458
mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
457459
: toSpringAiOptions.getDeploymentName());
458460

461+
mergedAzureOptions
462+
.setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed());
463+
464+
mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs())
465+
|| (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs()));
466+
467+
mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs()
468+
: toSpringAiOptions.getTopLogProbs());
469+
470+
mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null
471+
? fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements());
472+
459473
return mergedAzureOptions;
460474
}
461475

@@ -520,6 +534,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
520534
mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat()));
521535
}
522536

537+
if (fromSpringAiOptions.getSeed() != null) {
538+
mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed());
539+
}
540+
541+
if (fromSpringAiOptions.isLogprobs() != null) {
542+
mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs());
543+
}
544+
545+
if (fromSpringAiOptions.getTopLogProbs() != null) {
546+
mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs());
547+
}
548+
549+
if (fromSpringAiOptions.getEnhancements() != null) {
550+
mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements());
551+
}
552+
523553
return mergedAzureOptions;
524554
}
525555

@@ -566,6 +596,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
566596
if (fromOptions.getResponseFormat() != null) {
567597
copyOptions.setResponseFormat(fromOptions.getResponseFormat());
568598
}
599+
if (fromOptions.getSeed() != null) {
600+
copyOptions.setSeed(fromOptions.getSeed());
601+
}
602+
603+
copyOptions.setLogprobs(fromOptions.isLogprobs());
604+
605+
if (fromOptions.getTopLogprobs() != null) {
606+
copyOptions.setTopLogprobs(fromOptions.getTopLogprobs());
607+
}
608+
609+
if (fromOptions.getEnhancements() != null) {
610+
copyOptions.setEnhancements(fromOptions.getEnhancements());
611+
}
569612

570613
return copyOptions;
571614
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.azure.openai;
1718

1819
import java.util.ArrayList;
@@ -21,6 +22,7 @@
2122
import java.util.Map;
2223
import java.util.Set;
2324

25+
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
2426
import com.fasterxml.jackson.annotation.JsonIgnore;
2527
import com.fasterxml.jackson.annotation.JsonInclude;
2628
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -40,6 +42,7 @@
4042
*
4143
* @author Christian Tzolov
4244
* @author Thomas Vitale
45+
* @author Soby Chacko
4346
*/
4447
@JsonInclude(Include.NON_NULL)
4548
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
@@ -165,6 +168,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
165168
@JsonIgnore
166169
private Boolean proxyToolCalls;
167170

171+
/**
172+
* Seed value for deterministic sampling such that the same seed and parameters return
173+
* the same result.
174+
*/
175+
@JsonProperty(value = "seed")
176+
private Long seed;
177+
178+
/**
179+
* Whether to return log probabilities of the output tokens or not. If true, returns
180+
* the log probabilities of each output token returned in the `content` of `message`.
181+
* This option is currently not available on the `gpt-4-vision-preview` model.
182+
*/
183+
@JsonProperty(value = "log_probs")
184+
private Boolean logprobs;
185+
186+
/*
187+
* An integer between 0 and 5 specifying the number of most likely tokens to return at
188+
* each token position, each with an associated log probability. `logprobs` must be
189+
* set to `true` if this parameter is used.
190+
*/
191+
@JsonProperty(value = "top_log_probs")
192+
private Integer topLogProbs;
193+
194+
/*
195+
* If provided, the configuration options for available Azure OpenAI chat
196+
* enhancements.
197+
*/
198+
@NestedConfigurationProperty
199+
@JsonIgnore
200+
private AzureChatEnhancementConfiguration enhancements;
201+
168202
public static Builder builder() {
169203
return new Builder();
170204
}
@@ -259,6 +293,30 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) {
259293
return this;
260294
}
261295

296+
public Builder withSeed(Long seed) {
297+
Assert.notNull(seed, "seed must not be null");
298+
this.options.seed = seed;
299+
return this;
300+
}
301+
302+
public Builder withLogprobs(Boolean logprobs) {
303+
Assert.notNull(logprobs, "logprobs must not be null");
304+
this.options.logprobs = logprobs;
305+
return this;
306+
}
307+
308+
public Builder withTopLogprobs(Integer topLogprobs) {
309+
Assert.notNull(topLogprobs, "topLogprobs must not be null");
310+
this.options.topLogProbs = topLogprobs;
311+
return this;
312+
}
313+
314+
public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) {
315+
Assert.notNull(enhancements, "enhancements must not be null");
316+
this.options.enhancements = enhancements;
317+
return this;
318+
}
319+
262320
public AzureOpenAiChatOptions build() {
263321
return this.options;
264322
}
@@ -404,6 +462,38 @@ public Integer getTopK() {
404462
return null;
405463
}
406464

465+
public Long getSeed() {
466+
return this.seed;
467+
}
468+
469+
public void setSeed(Long seed) {
470+
this.seed = seed;
471+
}
472+
473+
public Boolean isLogprobs() {
474+
return this.logprobs;
475+
}
476+
477+
public void setLogprobs(Boolean logprobs) {
478+
this.logprobs = logprobs;
479+
}
480+
481+
public Integer getTopLogProbs() {
482+
return this.topLogProbs;
483+
}
484+
485+
public void setTopLogProbs(Integer topLogProbs) {
486+
this.topLogProbs = topLogProbs;
487+
}
488+
489+
public AzureChatEnhancementConfiguration getEnhancements() {
490+
return this.enhancements;
491+
}
492+
493+
public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
494+
this.enhancements = enhancements;
495+
}
496+
407497
@Override
408498
public Boolean getProxyToolCalls() {
409499
return this.proxyToolCalls;
@@ -432,6 +522,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
432522
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
433523
.withFunctions(fromOptions.getFunctions())
434524
.withResponseFormat(fromOptions.getResponseFormat())
525+
.withSeed(fromOptions.getSeed())
526+
.withLogprobs(fromOptions.isLogprobs())
527+
.withTopLogprobs(fromOptions.getTopLogProbs())
528+
.withEnhancements(fromOptions.getEnhancements())
435529
.build();
436530
}
437531

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023 - 2024 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -13,9 +13,12 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
1617
package org.springframework.ai.azure.openai;
1718

1819
import com.azure.ai.openai.OpenAIClient;
20+
import com.azure.ai.openai.models.AzureChatEnhancementConfiguration;
21+
import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration;
1922
import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat;
2023
import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat;
2124
import org.junit.jupiter.api.Test;
@@ -34,6 +37,7 @@
3437

3538
/**
3639
* @author Christian Tzolov
40+
* @author Soby Chacko
3741
*/
3842
public class AzureChatCompletionsOptionsTests {
3943

@@ -42,6 +46,9 @@ public void createRequestWithChatOptions() {
4246

4347
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
4448

49+
AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito
50+
.mock(AzureChatEnhancementConfiguration.class);
51+
4552
var defaultOptions = AzureOpenAiChatOptions.builder()
4653
.withDeploymentName("DEFAULT_MODEL")
4754
.withTemperature(66.6)
@@ -53,6 +60,10 @@ public void createRequestWithChatOptions() {
5360
.withStop(List.of("foo", "bar"))
5461
.withTopP(0.69)
5562
.withUser("user")
63+
.withSeed(123L)
64+
.withLogprobs(true)
65+
.withTopLogprobs(5)
66+
.withEnhancements(mockAzureChatEnhancementConfiguration)
5667
.withResponseFormat(AzureOpenAiResponseFormat.TEXT)
5768
.build();
5869

@@ -72,8 +83,15 @@ public void createRequestWithChatOptions() {
7283
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
7384
assertThat(requestOptions.getTopP()).isEqualTo(0.69);
7485
assertThat(requestOptions.getUser()).isEqualTo("user");
86+
assertThat(requestOptions.getSeed()).isEqualTo(123L);
87+
assertThat(requestOptions.isLogprobs()).isTrue();
88+
assertThat(requestOptions.getTopLogprobs()).isEqualTo(5);
89+
assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration);
7590
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class);
7691

92+
AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito
93+
.mock(AzureChatEnhancementConfiguration.class);
94+
7795
var runtimeOptions = AzureOpenAiChatOptions.builder()
7896
.withDeploymentName("PROMPT_MODEL")
7997
.withTemperature(99.9)
@@ -85,6 +103,10 @@ public void createRequestWithChatOptions() {
85103
.withStop(List.of("foo", "bar"))
86104
.withTopP(0.111)
87105
.withUser("user2")
106+
.withSeed(1234L)
107+
.withLogprobs(true)
108+
.withTopLogprobs(4)
109+
.withEnhancements(anotherMockAzureChatEnhancementConfiguration)
88110
.withResponseFormat(AzureOpenAiResponseFormat.JSON)
89111
.build();
90112

@@ -102,6 +124,10 @@ public void createRequestWithChatOptions() {
102124
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
103125
assertThat(requestOptions.getTopP()).isEqualTo(0.111);
104126
assertThat(requestOptions.getUser()).isEqualTo("user2");
127+
assertThat(requestOptions.getSeed()).isEqualTo(1234L);
128+
assertThat(requestOptions.isLogprobs()).isTrue();
129+
assertThat(requestOptions.getTopLogprobs()).isEqualTo(4);
130+
assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration);
105131
assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class);
106132
}
107133

0 commit comments

Comments
 (0)