Skip to content

Commit bf84d59

Browse files
committed
Streamline ChatOptions
* Surface more configuration APIs to ChatOptions * Use abstraction in Observations directly instead of dedicated implementation * Simplify metadata config in observations for defined models * Improve merging of runtime and default options in OpenAI * Fix missing option in Mistral AI Relates to gh-1148 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent af25430 commit bf84d59

File tree

59 files changed

+1031
-694
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+1031
-694
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
* The options to be used when sending a chat request to the Anthropic API.
3838
*
3939
* @author Christian Tzolov
40+
* @author Thomas Vitale
4041
* @since 1.0.0
4142
*/
4243
@JsonInclude(Include.NON_NULL)
@@ -149,6 +150,7 @@ public AnthropicChatOptions build() {
149150

150151
}
151152

153+
@Override
152154
public String getModel() {
153155
return model;
154156
}
@@ -157,6 +159,7 @@ public void setModel(String model) {
157159
this.model = model;
158160
}
159161

162+
@Override
160163
public Integer getMaxTokens() {
161164
return this.maxTokens;
162165
}
@@ -173,6 +176,7 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) {
173176
this.metadata = metadata;
174177
}
175178

179+
@Override
176180
public List<String> getStopSequences() {
177181
return this.stopSequences;
178182
}
@@ -199,6 +203,7 @@ public void setTopP(Float topP) {
199203
this.topP = topP;
200204
}
201205

206+
@Override
202207
public Integer getTopK() {
203208
return this.topK;
204209
}
@@ -229,6 +234,18 @@ public void setFunctions(Set<String> functions) {
229234
this.functions = functions;
230235
}
231236

237+
@Override
238+
@JsonIgnore
239+
public Float getFrequencyPenalty() {
240+
return null;
241+
}
242+
243+
@Override
244+
@JsonIgnore
245+
public Float getPresencePenalty() {
246+
return null;
247+
}
248+
232249
@Override
233250
public AnthropicChatOptions copy() {
234251
return fromOptions(this);

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
* prompt data.
3939
*
4040
* @author Christian Tzolov
41+
* @author Thomas Vitale
4142
*/
4243
@JsonInclude(Include.NON_NULL)
4344
public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
@@ -108,7 +109,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
108109
* output new topics.
109110
*/
110111
@JsonProperty(value = "presence_penalty")
111-
private Double presencePenalty;
112+
private Float presencePenalty;
112113

113114
/**
114115
* A value that influences the probability of generated tokens appearing based on
@@ -117,7 +118,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
117118
* model repeating the same statements verbatim.
118119
*/
119120
@JsonProperty(value = "frequency_penalty")
120-
private Double frequencyPenalty;
121+
private Float frequencyPenalty;
121122

122123
/**
123124
* The deployment name as defined in Azure Open AI Studio when creating a deployment
@@ -182,9 +183,7 @@ public Builder withDeploymentName(String deploymentName) {
182183
}
183184

184185
public Builder withFrequencyPenalty(Float frequencyPenalty) {
185-
if (frequencyPenalty != null) {
186-
this.options.frequencyPenalty = frequencyPenalty.doubleValue();
187-
}
186+
this.options.frequencyPenalty = frequencyPenalty;
188187
return this;
189188
}
190189

@@ -204,9 +203,7 @@ public Builder withN(Integer n) {
204203
}
205204

206205
public Builder withPresencePenalty(Float presencePenalty) {
207-
if (presencePenalty != null) {
208-
this.options.presencePenalty = presencePenalty.doubleValue();
209-
}
206+
this.options.presencePenalty = presencePenalty;
210207
return this;
211208
}
212209

@@ -259,6 +256,7 @@ public AzureOpenAiChatOptions build() {
259256

260257
}
261258

259+
@Override
262260
public Integer getMaxTokens() {
263261
return this.maxTokens;
264262
}
@@ -291,6 +289,17 @@ public void setN(Integer n) {
291289
this.n = n;
292290
}
293291

292+
@Override
293+
@JsonIgnore
294+
public List<String> getStopSequences() {
295+
return getStop();
296+
}
297+
298+
@JsonIgnore
299+
public void setStopSequences(List<String> stopSequences) {
300+
setStop(stopSequences);
301+
}
302+
294303
public List<String> getStop() {
295304
return this.stop;
296305
}
@@ -299,22 +308,35 @@ public void setStop(List<String> stop) {
299308
this.stop = stop;
300309
}
301310

302-
public Double getPresencePenalty() {
311+
@Override
312+
public Float getPresencePenalty() {
303313
return this.presencePenalty;
304314
}
305315

306-
public void setPresencePenalty(Double presencePenalty) {
316+
public void setPresencePenalty(Float presencePenalty) {
307317
this.presencePenalty = presencePenalty;
308318
}
309319

310-
public Double getFrequencyPenalty() {
320+
@Override
321+
public Float getFrequencyPenalty() {
311322
return this.frequencyPenalty;
312323
}
313324

314-
public void setFrequencyPenalty(Double frequencyPenalty) {
325+
public void setFrequencyPenalty(Float frequencyPenalty) {
315326
this.frequencyPenalty = frequencyPenalty;
316327
}
317328

329+
@Override
330+
@JsonIgnore
331+
public String getModel() {
332+
return getDeploymentName();
333+
}
334+
335+
@JsonIgnore
336+
public void setModel(String model) {
337+
setDeploymentName(model);
338+
}
339+
318340
public String getDeploymentName() {
319341
return this.deploymentName;
320342
}
@@ -341,17 +363,6 @@ public void setTopP(Float topP) {
341363
this.topP = topP;
342364
}
343365

344-
@Override
345-
@JsonIgnore
346-
public Integer getTopK() {
347-
throw new UnsupportedOperationException("Unimplemented method 'getTopK'");
348-
}
349-
350-
@JsonIgnore
351-
public void setTopK(Integer topK) {
352-
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
353-
}
354-
355366
@Override
356367
public List<FunctionCallback> getFunctionCallbacks() {
357368
return this.functionCallbacks;
@@ -378,20 +389,24 @@ public void setResponseFormat(AzureOpenAiResponseFormat responseFormat) {
378389
this.responseFormat = responseFormat;
379390
}
380391

392+
@Override
393+
@JsonIgnore
394+
public Integer getTopK() {
395+
return null;
396+
}
397+
381398
@Override
382399
public AzureOpenAiChatOptions copy() {
383400
return fromOptions(this);
384401
}
385402

386403
public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOptions) {
387404
return builder().withDeploymentName(fromOptions.getDeploymentName())
388-
.withFrequencyPenalty(
389-
fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty().floatValue() : null)
405+
.withFrequencyPenalty(fromOptions.getFrequencyPenalty() != null ? fromOptions.getFrequencyPenalty() : null)
390406
.withLogitBias(fromOptions.getLogitBias())
391407
.withMaxTokens(fromOptions.getMaxTokens())
392408
.withN(fromOptions.getN())
393-
.withPresencePenalty(
394-
fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty().floatValue() : null)
409+
.withPresencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null)
395410
.withStop(fromOptions.getStop())
396411
.withTemperature(fromOptions.getTemperature())
397412
.withTopP(fromOptions.getTopP())

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingOptions.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919

20+
import com.fasterxml.jackson.annotation.JsonIgnore;
2021
import org.springframework.ai.embedding.EmbeddingOptions;
2122

2223
/**
@@ -125,10 +126,16 @@ public AzureOpenAiEmbeddingOptions build() {
125126
}
126127

127128
@Override
129+
@JsonIgnore
128130
public String getModel() {
129131
return getDeploymentName();
130132
}
131133

134+
@JsonIgnore
135+
public void setModel(String model) {
136+
setDeploymentName(model);
137+
}
138+
132139
public String getUser() {
133140
return this.user;
134141
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic/AnthropicChatOptions.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919

20+
import com.fasterxml.jackson.annotation.JsonIgnore;
2021
import com.fasterxml.jackson.annotation.JsonInclude;
2122
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2223

@@ -26,6 +27,7 @@
2627

2728
/**
2829
* @author Christian Tzolov
30+
* @author Thomas Vitale
2931
*/
3032
@JsonInclude(Include.NON_NULL)
3133
public class AnthropicChatOptions implements ChatOptions {
@@ -122,6 +124,17 @@ public void setTemperature(Float temperature) {
122124
this.temperature = temperature;
123125
}
124126

127+
@Override
128+
@JsonIgnore
129+
public Integer getMaxTokens() {
130+
return getMaxTokensToSample();
131+
}
132+
133+
@JsonIgnore
134+
public void setMaxTokens(Integer maxTokens) {
135+
setMaxTokensToSample(maxTokens);
136+
}
137+
125138
public Integer getMaxTokensToSample() {
126139
return this.maxTokensToSample;
127140
}
@@ -148,6 +161,7 @@ public void setTopP(Float topP) {
148161
this.topP = topP;
149162
}
150163

164+
@Override
151165
public List<String> getStopSequences() {
152166
return this.stopSequences;
153167
}
@@ -164,6 +178,24 @@ public void setAnthropicVersion(String anthropicVersion) {
164178
this.anthropicVersion = anthropicVersion;
165179
}
166180

181+
@Override
182+
@JsonIgnore
183+
public String getModel() {
184+
return null;
185+
}
186+
187+
@Override
188+
@JsonIgnore
189+
public Float getFrequencyPenalty() {
190+
return null;
191+
}
192+
193+
@Override
194+
@JsonIgnore
195+
public Float getPresencePenalty() {
196+
return null;
197+
}
198+
167199
@Override
168200
public AnthropicChatOptions copy() {
169201
return fromOptions(this);

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/anthropic3/Anthropic3ChatOptions.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package org.springframework.ai.bedrock.anthropic3;
1717

18+
import com.fasterxml.jackson.annotation.JsonIgnore;
1819
import com.fasterxml.jackson.annotation.JsonInclude;
1920
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2021
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -24,6 +25,7 @@
2425

2526
/**
2627
* @author Ben Middleton
28+
* @author Thomas Vitale
2729
* @since 1.0.0
2830
*/
2931
@JsonInclude(Include.NON_NULL)
@@ -121,6 +123,7 @@ public void setTemperature(Float temperature) {
121123
this.temperature = temperature;
122124
}
123125

126+
@Override
124127
public Integer getMaxTokens() {
125128
return this.maxTokens;
126129
}
@@ -147,6 +150,7 @@ public void setTopP(Float topP) {
147150
this.topP = topP;
148151
}
149152

153+
@Override
150154
public List<String> getStopSequences() {
151155
return this.stopSequences;
152156
}
@@ -163,6 +167,24 @@ public void setAnthropicVersion(String anthropicVersion) {
163167
this.anthropicVersion = anthropicVersion;
164168
}
165169

170+
@Override
171+
@JsonIgnore
172+
public String getModel() {
173+
return null;
174+
}
175+
176+
@Override
177+
@JsonIgnore
178+
public Float getFrequencyPenalty() {
179+
return null;
180+
}
181+
182+
@Override
183+
@JsonIgnore
184+
public Float getPresencePenalty() {
185+
return null;
186+
}
187+
166188
@Override
167189
public Anthropic3ChatOptions copy() {
168190
return fromOptions(this);

0 commit comments

Comments
 (0)