Skip to content

Commit 93bd502

Browse files
committed
Improve Azure OpenAI options merging logic
1 parent 220ec7f commit 93bd502

File tree

2 files changed

+167
-65
lines changed

2 files changed

+167
-65
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 120 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -277,120 +277,135 @@ private <T> List<T> nullSafeList(List<T> list) {
277277
return list != null ? list : Collections.emptyList();
278278
}
279279

280-
// JSON merge doesn't due to Azure OpenAI service bug:
281-
// https://github.com/Azure/azure-sdk-for-java/issues/38183
282-
private ChatCompletionsOptions merge(ChatCompletionsOptions azureOptions, AzureOpenAiChatOptions springAiOptions) {
280+
/**
281+
* Merges the Azure's {@link ChatCompletionsOptions} (fromAzureOptions) into the
282+
* Spring AI's {@link AzureOpenAiChatOptions} (toSpringAiOptions) and return a new
283+
* {@link ChatCompletionsOptions} instance.
284+
*/
285+
private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions,
286+
AzureOpenAiChatOptions toSpringAiOptions) {
283287

284-
if (springAiOptions == null) {
285-
return azureOptions;
288+
if (toSpringAiOptions == null) {
289+
return fromAzureOptions;
286290
}
287291

288-
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
289-
mergedAzureOptions.setStream(azureOptions.isStream());
292+
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(fromAzureOptions.getMessages());
293+
mergedAzureOptions.setStream(fromAzureOptions.isStream());
290294

291-
mergedAzureOptions.setMaxTokens(
292-
(azureOptions.getMaxTokens() != null) ? azureOptions.getMaxTokens() : springAiOptions.getMaxTokens());
295+
mergedAzureOptions.setMaxTokens((fromAzureOptions.getMaxTokens() != null) ? fromAzureOptions.getMaxTokens()
296+
: toSpringAiOptions.getMaxTokens());
293297

294-
mergedAzureOptions.setLogitBias(
295-
azureOptions.getLogitBias() != null ? azureOptions.getLogitBias() : springAiOptions.getLogitBias());
298+
mergedAzureOptions.setLogitBias(fromAzureOptions.getLogitBias() != null ? fromAzureOptions.getLogitBias()
299+
: toSpringAiOptions.getLogitBias());
296300

297-
mergedAzureOptions.setStop(azureOptions.getStop() != null ? azureOptions.getStop() : springAiOptions.getStop());
301+
mergedAzureOptions
302+
.setStop(fromAzureOptions.getStop() != null ? fromAzureOptions.getStop() : toSpringAiOptions.getStop());
298303

299-
mergedAzureOptions.setTemperature(azureOptions.getTemperature());
300-
if (mergedAzureOptions.getTemperature() == null && springAiOptions.getTemperature() != null) {
301-
mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
304+
mergedAzureOptions.setTemperature(fromAzureOptions.getTemperature());
305+
if (mergedAzureOptions.getTemperature() == null && toSpringAiOptions.getTemperature() != null) {
306+
mergedAzureOptions.setTemperature(toSpringAiOptions.getTemperature().doubleValue());
302307
}
303308

304-
mergedAzureOptions.setTopP(azureOptions.getTopP());
305-
if (mergedAzureOptions.getTopP() == null && springAiOptions.getTopP() != null) {
306-
mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
309+
mergedAzureOptions.setTopP(fromAzureOptions.getTopP());
310+
if (mergedAzureOptions.getTopP() == null && toSpringAiOptions.getTopP() != null) {
311+
mergedAzureOptions.setTopP(toSpringAiOptions.getTopP().doubleValue());
307312
}
308313

309-
mergedAzureOptions.setFrequencyPenalty(azureOptions.getFrequencyPenalty());
310-
if (mergedAzureOptions.getFrequencyPenalty() == null && springAiOptions.getFrequencyPenalty() != null) {
311-
mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
314+
mergedAzureOptions.setFrequencyPenalty(fromAzureOptions.getFrequencyPenalty());
315+
if (mergedAzureOptions.getFrequencyPenalty() == null && toSpringAiOptions.getFrequencyPenalty() != null) {
316+
mergedAzureOptions.setFrequencyPenalty(toSpringAiOptions.getFrequencyPenalty().doubleValue());
312317
}
313318

314-
mergedAzureOptions.setPresencePenalty(azureOptions.getPresencePenalty());
315-
if (mergedAzureOptions.getPresencePenalty() == null && springAiOptions.getPresencePenalty() != null) {
316-
mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
319+
mergedAzureOptions.setPresencePenalty(fromAzureOptions.getPresencePenalty());
320+
if (mergedAzureOptions.getPresencePenalty() == null && toSpringAiOptions.getPresencePenalty() != null) {
321+
mergedAzureOptions.setPresencePenalty(toSpringAiOptions.getPresencePenalty().doubleValue());
317322
}
318323

319-
mergedAzureOptions.setN(azureOptions.getN() != null ? azureOptions.getN() : springAiOptions.getN());
320-
321-
mergedAzureOptions.setUser(azureOptions.getUser() != null ? azureOptions.getUser() : springAiOptions.getUser());
324+
mergedAzureOptions.setN(fromAzureOptions.getN() != null ? fromAzureOptions.getN() : toSpringAiOptions.getN());
322325

323326
mergedAzureOptions
324-
.setModel(azureOptions.getModel() != null ? azureOptions.getModel() : springAiOptions.getDeploymentName());
327+
.setUser(fromAzureOptions.getUser() != null ? fromAzureOptions.getUser() : toSpringAiOptions.getUser());
328+
329+
mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel()
330+
: toSpringAiOptions.getDeploymentName());
325331

326332
return mergedAzureOptions;
327333
}
328334

329-
// JSON merge doesn't due to Azure OpenAI service bug:
330-
// https://github.com/Azure/azure-sdk-for-java/issues/38183
331-
private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, ChatCompletionsOptions azureOptions) {
332-
if (springAiOptions == null) {
333-
return azureOptions;
334-
}
335+
/**
336+
* Merges the {@link AzureOpenAiChatOptions}, fromSpringAiOptions, into the
337+
* {@link ChatCompletionsOptions}, toAzureOptions, and returns a new
338+
* {@link ChatCompletionsOptions} instance.
339+
* @param fromSpringAiOptions the {@link AzureOpenAiChatOptions} to merge from.
340+
* @param toAzureOptions the {@link ChatCompletionsOptions} to merge to.
341+
* @return a new {@link ChatCompletionsOptions} instance.
342+
*/
343+
private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions,
344+
ChatCompletionsOptions toAzureOptions) {
335345

336-
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions(azureOptions.getMessages());
337-
mergedAzureOptions = merge(azureOptions, mergedAzureOptions);
346+
if (fromSpringAiOptions == null) {
347+
return toAzureOptions;
348+
}
338349

339-
mergedAzureOptions.setStream(azureOptions.isStream());
350+
ChatCompletionsOptions mergedAzureOptions = this.copy(toAzureOptions);
340351

341-
if (springAiOptions.getMaxTokens() != null) {
342-
mergedAzureOptions.setMaxTokens(springAiOptions.getMaxTokens());
352+
if (fromSpringAiOptions.getMaxTokens() != null) {
353+
mergedAzureOptions.setMaxTokens(fromSpringAiOptions.getMaxTokens());
343354
}
344355

345-
if (springAiOptions.getLogitBias() != null) {
346-
mergedAzureOptions.setLogitBias(springAiOptions.getLogitBias());
356+
if (fromSpringAiOptions.getLogitBias() != null) {
357+
mergedAzureOptions.setLogitBias(fromSpringAiOptions.getLogitBias());
347358
}
348359

349-
if (springAiOptions.getStop() != null) {
350-
mergedAzureOptions.setStop(springAiOptions.getStop());
360+
if (fromSpringAiOptions.getStop() != null) {
361+
mergedAzureOptions.setStop(fromSpringAiOptions.getStop());
351362
}
352363

353-
if (springAiOptions.getTemperature() != null && springAiOptions.getTemperature() != null) {
354-
mergedAzureOptions.setTemperature(springAiOptions.getTemperature().doubleValue());
364+
if (fromSpringAiOptions.getTemperature() != null) {
365+
mergedAzureOptions.setTemperature(fromSpringAiOptions.getTemperature().doubleValue());
355366
}
356367

357-
if (springAiOptions.getTopP() != null && springAiOptions.getTopP() != null) {
358-
mergedAzureOptions.setTopP(springAiOptions.getTopP().doubleValue());
368+
if (fromSpringAiOptions.getTopP() != null) {
369+
mergedAzureOptions.setTopP(fromSpringAiOptions.getTopP().doubleValue());
359370
}
360371

361-
if (springAiOptions.getFrequencyPenalty() != null && springAiOptions.getFrequencyPenalty() != null) {
362-
mergedAzureOptions.setFrequencyPenalty(springAiOptions.getFrequencyPenalty().doubleValue());
372+
if (fromSpringAiOptions.getFrequencyPenalty() != null) {
373+
mergedAzureOptions.setFrequencyPenalty(fromSpringAiOptions.getFrequencyPenalty().doubleValue());
363374
}
364375

365-
if (springAiOptions.getPresencePenalty() != null && springAiOptions.getPresencePenalty() != null) {
366-
mergedAzureOptions.setPresencePenalty(springAiOptions.getPresencePenalty().doubleValue());
376+
if (fromSpringAiOptions.getPresencePenalty() != null) {
377+
mergedAzureOptions.setPresencePenalty(fromSpringAiOptions.getPresencePenalty().doubleValue());
367378
}
368379

369-
if (springAiOptions.getN() != null) {
370-
mergedAzureOptions.setN(springAiOptions.getN());
380+
if (fromSpringAiOptions.getN() != null) {
381+
mergedAzureOptions.setN(fromSpringAiOptions.getN());
371382
}
372383

373-
if (springAiOptions.getUser() != null) {
374-
mergedAzureOptions.setUser(springAiOptions.getUser());
384+
if (fromSpringAiOptions.getUser() != null) {
385+
mergedAzureOptions.setUser(fromSpringAiOptions.getUser());
375386
}
376387

377-
if (springAiOptions.getDeploymentName() != null) {
378-
mergedAzureOptions.setModel(springAiOptions.getDeploymentName());
388+
if (fromSpringAiOptions.getDeploymentName() != null) {
389+
mergedAzureOptions.setModel(fromSpringAiOptions.getDeploymentName());
379390
}
380391

381392
return mergedAzureOptions;
382393
}
383394

384-
// https://github.com/Azure/azure-sdk-for-java/blob/azure-ai-openai_1.0.0-beta.6/sdk/openai/azure-ai-openai/src/samples/java/com/azure/ai/openai/usage/GetChatCompletionsToolCallSample.java
385-
395+
/**
396+
* Merges the fromOptions into the toOptions and returns a new ChatCompletionsOptions
397+
* instance.
398+
* @param fromOptions the ChatCompletionsOptions to merge from.
399+
* @param toOptions the ChatCompletionsOptions to merge to.
400+
* @return a new ChatCompletionsOptions instance.
401+
*/
386402
private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCompletionsOptions toOptions) {
387403

388404
if (fromOptions == null) {
389405
return toOptions;
390406
}
391407

392-
ChatCompletionsOptions mergedOptions = new ChatCompletionsOptions(toOptions.getMessages());
393-
mergedOptions.setStream(toOptions.isStream());
408+
ChatCompletionsOptions mergedOptions = this.copy(toOptions);
394409

395410
if (fromOptions.getMaxTokens() != null) {
396411
mergedOptions.setMaxTokens(fromOptions.getMaxTokens());
@@ -426,6 +441,50 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
426441
return mergedOptions;
427442
}
428443

444+
/**
445+
* Copy the fromOptions into a new ChatCompletionsOptions instance.
446+
* @param fromOptions the ChatCompletionsOptions to copy from.
447+
* @return a new ChatCompletionsOptions instance.
448+
*/
449+
private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
450+
451+
ChatCompletionsOptions copyOptions = new ChatCompletionsOptions(fromOptions.getMessages());
452+
copyOptions.setStream(fromOptions.isStream());
453+
454+
if (fromOptions.getMaxTokens() != null) {
455+
copyOptions.setMaxTokens(fromOptions.getMaxTokens());
456+
}
457+
if (fromOptions.getLogitBias() != null) {
458+
copyOptions.setLogitBias(fromOptions.getLogitBias());
459+
}
460+
if (fromOptions.getStop() != null) {
461+
copyOptions.setStop(fromOptions.getStop());
462+
}
463+
if (fromOptions.getTemperature() != null) {
464+
copyOptions.setTemperature(fromOptions.getTemperature());
465+
}
466+
if (fromOptions.getTopP() != null) {
467+
copyOptions.setTopP(fromOptions.getTopP());
468+
}
469+
if (fromOptions.getFrequencyPenalty() != null) {
470+
copyOptions.setFrequencyPenalty(fromOptions.getFrequencyPenalty());
471+
}
472+
if (fromOptions.getPresencePenalty() != null) {
473+
copyOptions.setPresencePenalty(fromOptions.getPresencePenalty());
474+
}
475+
if (fromOptions.getN() != null) {
476+
copyOptions.setN(fromOptions.getN());
477+
}
478+
if (fromOptions.getUser() != null) {
479+
copyOptions.setUser(fromOptions.getUser());
480+
}
481+
if (fromOptions.getModel() != null) {
482+
copyOptions.setModel(fromOptions.getModel());
483+
}
484+
485+
return copyOptions;
486+
}
487+
429488
@Override
430489
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
431490
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import org.springframework.ai.chat.prompt.Prompt;
2626

27+
import java.util.List;
28+
import java.util.Map;
2729
import java.util.stream.Stream;
2830

2931
import static org.assertj.core.api.Assertions.assertThat;
@@ -37,23 +39,64 @@ public class AzureChatCompletionsOptionsTests {
3739
public void createRequestWithChatOptions() {
3840

3941
OpenAIClient mockClient = Mockito.mock(OpenAIClient.class);
40-
var client = new AzureOpenAiChatClient(mockClient,
41-
AzureOpenAiChatOptions.builder().withDeploymentName("DEFAULT_MODEL").withTemperature(66.6f).build());
42+
43+
var defaultOptions = AzureOpenAiChatOptions.builder()
44+
.withDeploymentName("DEFAULT_MODEL")
45+
.withTemperature(66.6f)
46+
.withFrequencyPenalty(696.9f)
47+
.withPresencePenalty(969.6f)
48+
.withLogitBias(Map.of("foo", 1))
49+
.withMaxTokens(969)
50+
.withN(69)
51+
.withStop(List.of("foo", "bar"))
52+
.withTopP(0.69f)
53+
.withUser("user")
54+
.build();
55+
56+
var client = new AzureOpenAiChatClient(mockClient, defaultOptions);
4257

4358
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));
4459

4560
assertThat(requestOptions.getMessages()).hasSize(1);
4661

4762
assertThat(requestOptions.getModel()).isEqualTo("DEFAULT_MODEL");
4863
assertThat(requestOptions.getTemperature()).isEqualTo(66.6f);
64+
assertThat(requestOptions.getFrequencyPenalty()).isEqualTo(696.9f);
65+
assertThat(requestOptions.getPresencePenalty()).isEqualTo(969.6f);
66+
assertThat(requestOptions.getLogitBias()).isEqualTo(Map.of("foo", 1));
67+
assertThat(requestOptions.getMaxTokens()).isEqualTo(969);
68+
assertThat(requestOptions.getN()).isEqualTo(69);
69+
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
70+
assertThat(requestOptions.getTopP()).isEqualTo(0.69f);
71+
assertThat(requestOptions.getUser()).isEqualTo("user");
72+
73+
var runtimeOptions = AzureOpenAiChatOptions.builder()
74+
.withDeploymentName("PROMPT_MODEL")
75+
.withTemperature(99.9f)
76+
.withFrequencyPenalty(100f)
77+
.withPresencePenalty(100f)
78+
.withLogitBias(Map.of("foo", 2))
79+
.withMaxTokens(100)
80+
.withN(100)
81+
.withStop(List.of("foo", "bar"))
82+
.withTopP(0.111f)
83+
.withUser("user2")
84+
.build();
4985

50-
requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content",
51-
AzureOpenAiChatOptions.builder().withDeploymentName("PROMPT_MODEL").withTemperature(99.9f).build()));
86+
requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content", runtimeOptions));
5287

5388
assertThat(requestOptions.getMessages()).hasSize(1);
5489

5590
assertThat(requestOptions.getModel()).isEqualTo("PROMPT_MODEL");
5691
assertThat(requestOptions.getTemperature()).isEqualTo(99.9f);
92+
assertThat(requestOptions.getFrequencyPenalty()).isEqualTo(100f);
93+
assertThat(requestOptions.getPresencePenalty()).isEqualTo(100f);
94+
assertThat(requestOptions.getLogitBias()).isEqualTo(Map.of("foo", 2));
95+
assertThat(requestOptions.getMaxTokens()).isEqualTo(100);
96+
assertThat(requestOptions.getN()).isEqualTo(100);
97+
assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar"));
98+
assertThat(requestOptions.getTopP()).isEqualTo(0.111f);
99+
assertThat(requestOptions.getUser()).isEqualTo("user2");
57100
}
58101

59102
private static Stream<Arguments> providePresencePenaltyAndFrequencyPenaltyTest() {

0 commit comments

Comments
 (0)