Skip to content

Commit 8f20aab

Browse files
apappascsilayaperumalg
authored andcommitted
feat: Add equals, hashCode, deep copy, tests to OpenAiChatOptions
This commit enhances OpenAiChatOptions by: - Updating copy() method, creating new instances of mutable collections (List, Set, Map, Metadata) to prevent shared state. - Adding OpenAiChatOptionsTests to verify copy(), builders, setters, and default values. Signed-off-by: Alexandros Pappas <apappascs@gmail.com>
1 parent 25c584a commit 8f20aab

File tree

2 files changed

+270
-6
lines changed

2 files changed

+270
-6
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,24 +237,26 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) {
237237
.maxTokens(fromOptions.getMaxTokens())
238238
.maxCompletionTokens(fromOptions.getMaxCompletionTokens())
239239
.N(fromOptions.getN())
240-
.outputModalities(fromOptions.getOutputModalities())
240+
.outputModalities(fromOptions.getOutputModalities() != null
241+
? new ArrayList<>(fromOptions.getOutputModalities()) : null)
241242
.outputAudio(fromOptions.getOutputAudio())
242243
.presencePenalty(fromOptions.getPresencePenalty())
243244
.responseFormat(fromOptions.getResponseFormat())
244245
.streamUsage(fromOptions.getStreamUsage())
245246
.seed(fromOptions.getSeed())
246-
.stop(fromOptions.getStop())
247+
.stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null)
247248
.temperature(fromOptions.getTemperature())
248249
.topP(fromOptions.getTopP())
249250
.tools(fromOptions.getTools())
250251
.toolChoice(fromOptions.getToolChoice())
251252
.user(fromOptions.getUser())
252253
.parallelToolCalls(fromOptions.getParallelToolCalls())
253-
.toolCallbacks(fromOptions.getToolCallbacks())
254-
.toolNames(fromOptions.getToolNames())
255-
.httpHeaders(fromOptions.getHttpHeaders())
254+
.toolCallbacks(
255+
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
256+
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
257+
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
256258
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
257-
.toolContext(fromOptions.getToolContext())
259+
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
258260
.store(fromOptions.getStore())
259261
.metadata(fromOptions.getMetadata())
260262
.reasoningEffort(fromOptions.getReasoningEffort())
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
/*
2+
* Copyright 2025-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import java.util.ArrayList;
20+
import java.util.HashMap;
21+
import java.util.List;
22+
import java.util.Map;
23+
24+
import static org.assertj.core.api.Assertions.assertThat;
25+
import org.junit.jupiter.api.Test;
26+
import static org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters.Voice.ALLOY;
27+
28+
import org.springframework.ai.openai.api.OpenAiApi;
29+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters;
30+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
31+
import org.springframework.ai.openai.api.ResponseFormat;
32+
33+
/**
34+
* Tests for {@link OpenAiChatOptions}.
35+
*
36+
* @author Alexandros Pappas
37+
*/
38+
class OpenAiChatOptionsTests {
39+
40+
@Test
41+
void testBuilderWithAllFields() {
42+
Map<String, Integer> logitBias = new HashMap<>();
43+
logitBias.put("token1", 1);
44+
logitBias.put("token2", -1);
45+
46+
List<String> outputModalities = List.of("text", "audio");
47+
AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3);
48+
ResponseFormat responseFormat = new ResponseFormat();
49+
StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE;
50+
List<String> stopSequences = List.of("stop1", "stop2");
51+
List<OpenAiApi.FunctionTool> tools = new ArrayList<>();
52+
Object toolChoice = "auto";
53+
Map<String, String> metadata = Map.of("key1", "value1");
54+
Map<String, Object> toolContext = Map.of("keyA", "valueA");
55+
56+
OpenAiChatOptions options = OpenAiChatOptions.builder()
57+
.model("test-model")
58+
.frequencyPenalty(0.5)
59+
.logitBias(logitBias)
60+
.logprobs(true)
61+
.topLogprobs(5)
62+
.maxTokens(100)
63+
.maxCompletionTokens(50)
64+
.N(2)
65+
.outputModalities(outputModalities)
66+
.outputAudio(outputAudio)
67+
.presencePenalty(0.8)
68+
.responseFormat(responseFormat)
69+
.streamUsage(true)
70+
.seed(12345)
71+
.stop(stopSequences)
72+
.temperature(0.7)
73+
.topP(0.9)
74+
.tools(tools)
75+
.toolChoice(toolChoice)
76+
.user("test-user")
77+
.parallelToolCalls(true)
78+
.store(false)
79+
.metadata(metadata)
80+
.reasoningEffort("medium")
81+
.proxyToolCalls(false)
82+
.httpHeaders(Map.of("header1", "value1"))
83+
.toolContext(toolContext)
84+
.build();
85+
86+
assertThat(options)
87+
.extracting("model", "frequencyPenalty", "logitBias", "logprobs", "topLogprobs", "maxTokens",
88+
"maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat",
89+
"streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user",
90+
"parallelToolCalls", "store", "metadata", "reasoningEffort", "proxyToolCalls", "httpHeaders",
91+
"toolContext")
92+
.containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8,
93+
responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true,
94+
false, metadata, "medium", false, Map.of("header1", "value1"), toolContext);
95+
96+
assertThat(options.getStreamUsage()).isTrue();
97+
assertThat(options.getStreamOptions()).isEqualTo(StreamOptions.INCLUDE_USAGE);
98+
99+
}
100+
101+
@Test
102+
void testCopy() {
103+
Map<String, Integer> logitBias = new HashMap<>();
104+
logitBias.put("token1", 1);
105+
106+
List<String> outputModalities = List.of("text");
107+
AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3);
108+
ResponseFormat responseFormat = new ResponseFormat();
109+
110+
List<String> stopSequences = List.of("stop1");
111+
List<OpenAiApi.FunctionTool> tools = new ArrayList<>();
112+
Object toolChoice = "none";
113+
Map<String, String> metadata = Map.of("key1", "value1");
114+
115+
OpenAiChatOptions originalOptions = OpenAiChatOptions.builder()
116+
.model("test-model")
117+
.frequencyPenalty(0.5)
118+
.logitBias(logitBias)
119+
.logprobs(true)
120+
.topLogprobs(5)
121+
.maxTokens(100)
122+
.maxCompletionTokens(50)
123+
.N(2)
124+
.outputModalities(outputModalities)
125+
.outputAudio(outputAudio)
126+
.presencePenalty(0.8)
127+
.responseFormat(responseFormat)
128+
.streamUsage(false)
129+
.seed(12345)
130+
.stop(stopSequences)
131+
.temperature(0.7)
132+
.topP(0.9)
133+
.tools(tools)
134+
.toolChoice(toolChoice)
135+
.user("test-user")
136+
.parallelToolCalls(false)
137+
.store(true)
138+
.metadata(metadata)
139+
.reasoningEffort("low")
140+
.proxyToolCalls(true)
141+
.httpHeaders(Map.of("header1", "value1"))
142+
.build();
143+
144+
OpenAiChatOptions copiedOptions = originalOptions.copy();
145+
assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions);
146+
}
147+
148+
@Test
149+
void testSetters() {
150+
Map<String, Integer> logitBias = new HashMap<>();
151+
logitBias.put("token1", 1);
152+
153+
List<String> outputModalities = List.of("audio");
154+
AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3);
155+
ResponseFormat responseFormat = new ResponseFormat();
156+
157+
StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE;
158+
List<String> stopSequences = List.of("stop1", "stop2");
159+
List<OpenAiApi.FunctionTool> tools = new ArrayList<>();
160+
Object toolChoice = "auto";
161+
Map<String, String> metadata = Map.of("key2", "value2");
162+
163+
OpenAiChatOptions options = new OpenAiChatOptions();
164+
options.setModel("test-model");
165+
options.setFrequencyPenalty(0.5);
166+
options.setLogitBias(logitBias);
167+
options.setLogprobs(true);
168+
options.setTopLogprobs(5);
169+
options.setMaxTokens(100);
170+
options.setMaxCompletionTokens(50);
171+
options.setN(2);
172+
options.setOutputModalities(outputModalities);
173+
options.setOutputAudio(outputAudio);
174+
options.setPresencePenalty(0.8);
175+
options.setResponseFormat(responseFormat);
176+
options.setStreamOptions(streamOptions);
177+
options.setSeed(12345);
178+
options.setStop(stopSequences);
179+
options.setTemperature(0.7);
180+
options.setTopP(0.9);
181+
options.setTools(tools);
182+
options.setToolChoice(toolChoice);
183+
options.setUser("test-user");
184+
options.setParallelToolCalls(true);
185+
options.setStore(false);
186+
options.setMetadata(metadata);
187+
options.setReasoningEffort("high");
188+
options.setProxyToolCalls(false);
189+
options.setHttpHeaders(Map.of("header2", "value2"));
190+
191+
assertThat(options.getModel()).isEqualTo("test-model");
192+
assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
193+
assertThat(options.getLogitBias()).isEqualTo(logitBias);
194+
assertThat(options.getLogprobs()).isTrue();
195+
assertThat(options.getTopLogprobs()).isEqualTo(5);
196+
assertThat(options.getMaxTokens()).isEqualTo(100);
197+
assertThat(options.getMaxCompletionTokens()).isEqualTo(50);
198+
assertThat(options.getN()).isEqualTo(2);
199+
assertThat(options.getOutputModalities()).isEqualTo(outputModalities);
200+
assertThat(options.getOutputAudio()).isEqualTo(outputAudio);
201+
assertThat(options.getPresencePenalty()).isEqualTo(0.8);
202+
assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
203+
assertThat(options.getStreamOptions()).isEqualTo(streamOptions);
204+
assertThat(options.getSeed()).isEqualTo(12345);
205+
assertThat(options.getStop()).isEqualTo(stopSequences);
206+
assertThat(options.getTemperature()).isEqualTo(0.7);
207+
assertThat(options.getTopP()).isEqualTo(0.9);
208+
assertThat(options.getTools()).isEqualTo(tools);
209+
assertThat(options.getToolChoice()).isEqualTo(toolChoice);
210+
assertThat(options.getUser()).isEqualTo("test-user");
211+
assertThat(options.getParallelToolCalls()).isTrue();
212+
assertThat(options.getStore()).isFalse();
213+
assertThat(options.getMetadata()).isEqualTo(metadata);
214+
assertThat(options.getReasoningEffort()).isEqualTo("high");
215+
assertThat(options.getProxyToolCalls()).isFalse();
216+
assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2"));
217+
assertThat(options.getStreamUsage()).isTrue();
218+
options.setStreamUsage(false);
219+
assertThat(options.getStreamUsage()).isFalse();
220+
assertThat(options.getStreamOptions()).isNull();
221+
options.setStopSequences(List.of("s1", "s2"));
222+
assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2"));
223+
assertThat(options.getStop()).isEqualTo(List.of("s1", "s2"));
224+
}
225+
226+
@Test
227+
void testDefaultValues() {
228+
OpenAiChatOptions options = new OpenAiChatOptions();
229+
assertThat(options.getModel()).isNull();
230+
assertThat(options.getFrequencyPenalty()).isNull();
231+
assertThat(options.getLogitBias()).isNull();
232+
assertThat(options.getLogprobs()).isNull();
233+
assertThat(options.getTopLogprobs()).isNull();
234+
assertThat(options.getMaxTokens()).isNull();
235+
assertThat(options.getMaxCompletionTokens()).isNull();
236+
assertThat(options.getN()).isNull();
237+
assertThat(options.getOutputModalities()).isNull();
238+
assertThat(options.getOutputAudio()).isNull();
239+
assertThat(options.getPresencePenalty()).isNull();
240+
assertThat(options.getResponseFormat()).isNull();
241+
assertThat(options.getStreamOptions()).isNull();
242+
assertThat(options.getSeed()).isNull();
243+
assertThat(options.getStop()).isNull();
244+
assertThat(options.getTemperature()).isNull();
245+
assertThat(options.getTopP()).isNull();
246+
assertThat(options.getTools()).isNull();
247+
assertThat(options.getToolChoice()).isNull();
248+
assertThat(options.getUser()).isNull();
249+
assertThat(options.getParallelToolCalls()).isNull();
250+
assertThat(options.getStore()).isNull();
251+
assertThat(options.getMetadata()).isNull();
252+
assertThat(options.getReasoningEffort()).isNull();
253+
assertThat(options.getFunctionCallbacks()).isNotNull().isEmpty();
254+
assertThat(options.getFunctions()).isNotNull().isEmpty();
255+
assertThat(options.getProxyToolCalls()).isNull();
256+
assertThat(options.getHttpHeaders()).isNotNull().isEmpty();
257+
assertThat(options.getToolContext()).isEqualTo(new HashMap<>());
258+
assertThat(options.getStreamUsage()).isFalse();
259+
assertThat(options.getStopSequences()).isNull();
260+
}
261+
262+
}

0 commit comments

Comments
 (0)