Skip to content

Commit 11d0578

Browse files
ThomasVitale authored and tzolov committed
OpenAI: Added missing fields to API
Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 4539a41 commit 11d0578

File tree

4 files changed

+83
-6
lines changed

4 files changed

+83
-6
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 51 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -59,6 +59,17 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
5959
* or 100 should result in a ban or exclusive selection of the relevant token.
6060
*/
6161
private @JsonProperty("logit_bias") Map<String, Integer> logitBias;
62+
/**
63+
* Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities
64+
* of each output token returned in the 'content' of 'message'. This option is currently not available
65+
* on the 'gpt-4-vision-preview' model.
66+
*/
67+
private @JsonProperty("logprobs") Boolean logprobs;
68+
/**
69+
* An integer between 0 and 5 specifying the number of most likely tokens to return at each token position,
70+
* each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used.
71+
*/
72+
private @JsonProperty("top_logprobs") Integer topLogprobs;
6273
/**
6374
* The maximum number of tokens to generate in the chat completion. The total length of input
6475
* tokens and generated tokens is limited by the model's context length.
@@ -177,6 +188,16 @@ public Builder withLogitBias(Map<String, Integer> logitBias) {
177188
return this;
178189
}
179190

191+
public Builder withLogprobs(Boolean logprobs) {
192+
this.options.logprobs = logprobs;
193+
return this;
194+
}
195+
196+
public Builder withTopLogprobs(Integer topLogprobs) {
197+
this.options.topLogprobs = topLogprobs;
198+
return this;
199+
}
200+
180201
public Builder withMaxTokens(Integer maxTokens) {
181202
this.options.maxTokens = maxTokens;
182203
return this;
@@ -279,6 +300,22 @@ public void setLogitBias(Map<String, Integer> logitBias) {
279300
this.logitBias = logitBias;
280301
}
281302

303+
public Boolean getLogprobs() {
304+
return this.logprobs;
305+
}
306+
307+
public void setLogprobs(Boolean logprobs) {
308+
this.logprobs = logprobs;
309+
}
310+
311+
public Integer getTopLogprobs() {
312+
return this.topLogprobs;
313+
}
314+
315+
public void setTopLogprobs(Integer topLogprobs) {
316+
this.topLogprobs = topLogprobs;
317+
}
318+
282319
public Integer getMaxTokens() {
283320
return this.maxTokens;
284321
}
@@ -395,6 +432,8 @@ public int hashCode() {
395432
result = prime * result + ((model == null) ? 0 : model.hashCode());
396433
result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode());
397434
result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode());
435+
result = prime * result + ((logprobs == null) ? 0 : logprobs.hashCode());
436+
result = prime * result + ((topLogprobs == null) ? 0 : topLogprobs.hashCode());
398437
result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode());
399438
result = prime * result + ((n == null) ? 0 : n.hashCode());
400439
result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode());
@@ -436,6 +475,18 @@ else if (!this.frequencyPenalty.equals(other.frequencyPenalty))
436475
}
437476
else if (!this.logitBias.equals(other.logitBias))
438477
return false;
478+
if (this.logprobs == null) {
479+
if (other.logprobs != null)
480+
return false;
481+
}
482+
else if (!this.logprobs.equals(other.logprobs))
483+
return false;
484+
if (this.topLogprobs == null) {
485+
if (other.topLogprobs != null)
486+
return false;
487+
}
488+
else if (!this.topLogprobs.equals(other.topLogprobs))
489+
return false;
439490
if (this.maxTokens == null) {
440491
if (other.maxTokens != null)
441492
return false;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -118,7 +118,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
118118
org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<List<String>> apiRequest = (this.defaultOptions != null)
119119
? new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
120120
this.defaultOptions.getModel(), this.defaultOptions.getEncodingFormat(),
121-
this.defaultOptions.getUser())
121+
this.defaultOptions.getDimensions(), this.defaultOptions.getUser())
122122
: new org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest<>(request.getInstructions(),
123123
OpenAiApi.DEFAULT_EMBEDDING_MODEL);
124124

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java

Lines changed: 17 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -37,6 +37,10 @@ public class OpenAiEmbeddingOptions implements EmbeddingOptions {
3737
* The format to return the embeddings in. Can be either float or base64.
3838
*/
3939
private @JsonProperty("encoding_format") String encodingFormat;
40+
/**
41+
* The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
42+
*/
43+
private @JsonProperty("dimensions") Integer dimensions;
4044
/**
4145
* A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
4246
*/
@@ -65,6 +69,11 @@ public Builder withEncodingFormat(String encodingFormat) {
6569
return this;
6670
}
6771

72+
public Builder withDimensions(Integer dimensions) {
73+
this.options.dimensions = dimensions;
74+
return this;
75+
}
76+
6877
public Builder withUser(String user) {
6978
this.options.setUser(user);
7079
return this;
@@ -92,6 +101,14 @@ public void setEncodingFormat(String encodingFormat) {
92101
this.encodingFormat = encodingFormat;
93102
}
94103

104+
public Integer getDimensions() {
105+
return dimensions;
106+
}
107+
108+
public void setDimensions(Integer dimensions) {
109+
this.dimensions = dimensions;
110+
}
111+
95112
public String getUser() {
96113
return user;
97114
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 14 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -265,6 +265,11 @@ public Function(String description, String name, String jsonSchema) {
265265
* Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
266266
* vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100
267267
* or 100 should result in a ban or exclusive selection of the relevant token.
268+
* @param logprobs Whether to return log probabilities of the output tokens or not. If true, returns the log
269+
* probabilities of each output token returned in the 'content' of 'message'. This option is currently not available
270+
* on the 'gpt-4-vision-preview' model.
271+
* @param topLogprobs An integer between 0 and 5 specifying the number of most likely tokens to return at each token
272+
* position, each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used.
268273
* @param maxTokens The maximum number of tokens to generate in the chat completion. The total length of input
269274
* tokens and generated tokens is limited by the model's context length.
270275
* @param n How many chat completion choices to generate for each input message. Note that you will be charged based
@@ -302,6 +307,8 @@ public record ChatCompletionRequest (
302307
@JsonProperty("model") String model,
303308
@JsonProperty("frequency_penalty") Float frequencyPenalty,
304309
@JsonProperty("logit_bias") Map<String, Integer> logitBias,
310+
@JsonProperty("logprobs") Boolean logprobs,
311+
@JsonProperty("top_logprobs") Integer topLogprobs,
305312
@JsonProperty("max_tokens") Integer maxTokens,
306313
@JsonProperty("n") Integer n,
307314
@JsonProperty("presence_penalty") Float presencePenalty,
@@ -323,7 +330,7 @@ public record ChatCompletionRequest (
323330
* @param temperature What sampling temperature to use, between 0 and 1.
324331
*/
325332
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) {
326-
this(messages, model, null, null, null, null, null,
333+
this(messages, model, null, null, null, null, null, null, null,
327334
null, null, null, false, temperature, null,
328335
null, null, null);
329336
}
@@ -338,7 +345,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
338345
* as they become available, with the stream terminated by a data: [DONE] message.
339346
*/
340347
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature, boolean stream) {
341-
this(messages, model, null, null, null, null, null,
348+
this(messages, model, null, null, null, null, null, null, null,
342349
null, null, null, stream, temperature, null,
343350
null, null, null);
344351
}
@@ -354,7 +361,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
354361
*/
355362
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
356363
List<FunctionTool> tools, String toolChoice) {
357-
this(messages, model, null, null, null, null, null,
364+
this(messages, model, null, null, null, null, null, null, null,
358365
null, null, null, false, 0.8f, null,
359366
tools, toolChoice, null);
360367
}
@@ -368,7 +375,7 @@ public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model,
368375
* as they become available, with the stream terminated by a data: [DONE] message.
369376
*/
370377
public ChatCompletionRequest(List<ChatCompletionMessage> messages, Boolean stream) {
371-
this(messages, null, null, null, null, null, null,
378+
this(messages, null, null, null, null, null, null, null, null,
372379
null, null, null, stream, null, null,
373380
null, null, null);
374381
}
@@ -869,13 +876,15 @@ public Embedding(Integer index, List<Double> embedding) {
869876
* dimensions or less.
870877
* @param model ID of the model to use.
871878
* @param encodingFormat The format to return the embeddings in. Can be either float or base64.
879+
* @param dimensions The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
872880
* @param user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
873881
*/
874882
@JsonInclude(Include.NON_NULL)
875883
public record EmbeddingRequest<T>(
876884
@JsonProperty("input") T input,
877885
@JsonProperty("model") String model,
878886
@JsonProperty("encoding_format") String encodingFormat,
887+
@JsonProperty("dimensions") Integer dimensions,
879888
@JsonProperty("user") String user) {
880889

881890
/**
@@ -884,7 +893,7 @@ public record EmbeddingRequest<T>(
884893
* @param model ID of the model to use.
885894
*/
886895
public EmbeddingRequest(T input, String model) {
887-
this(input, model, "float", null);
896+
this(input, model, "float", null, null);
888897
}
889898

890899
/**

0 commit comments

Comments (0)