Skip to content

Commit 5b4784f

Browse files
tzolov authored and markpollack
committed
Add OpenAi Chat Options
Add OpenAiChatOptions - OpenAiChatOptions implements ChatOptions and exposes all OpenAi request options, except messages and stream. - Add OpenAiChatOptions field (as defaultOptions) to OpenAiChatClient. Implement start-up/runtime options merging on chat request creation - Add OpenAiChatOptions options field to OpenAiChatProperties. The latter is set as OpenAiChatClient#defaultOptions. Use the spring.ai.openai.chat.options.* prefix to set the options. - Add tests for properties and options merging. - Move the OpenAiChatOptions.java out of the api package Add OpenAiEmbeddingOptions - Add OpenAiEmbeddingOptions class implementing the EmbeddingOptions interface. - Add OpenAiEmbeddingClient#defaultOptions - Add request merging with default and prompt options. - Add OpenAiEmbeddingProperties#options field of type OpenAiEmbeddingOptions Update OpenAI client docs Part of the larger 'epic' issue #228
1 parent 7c7392e commit 5b4784f

File tree

16 files changed

+962
-104
lines changed

16 files changed

+962
-104
lines changed

models/spring-ai-openai/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
<version>2.0.4</version>
3535
</dependency>
3636

37+
<!-- NOTE: Required only by the @ConstructorBinding. -->
38+
<dependency>
39+
<groupId>org.springframework.boot</groupId>
40+
<artifactId>spring-boot</artifactId>
41+
</dependency>
42+
3743
<dependency>
3844
<groupId>io.rest-assured</groupId>
3945
<artifactId>json-path</artifactId>

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@
2828
import org.springframework.ai.chat.ChatResponse;
2929
import org.springframework.ai.chat.Generation;
3030
import org.springframework.ai.chat.StreamingChatClient;
31-
import org.springframework.ai.chat.messages.Message;
3231
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3332
import org.springframework.ai.chat.metadata.RateLimit;
3433
import org.springframework.ai.chat.prompt.Prompt;
34+
import org.springframework.ai.model.ModelOptionsUtils;
3535
import org.springframework.ai.openai.api.OpenAiApi;
3636
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
3737
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
38+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
3839
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
3940
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
4041
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
@@ -57,12 +58,13 @@
5758
*/
5859
public class OpenAiChatClient implements ChatClient, StreamingChatClient {
5960

60-
private Double temperature = 0.7;
61-
62-
private String model = "gpt-3.5-turbo";
63-
6461
private final Logger logger = LoggerFactory.getLogger(getClass());
6562

63+
private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
64+
.withModel("gpt-3.5-turbo")
65+
.withTemperature(0.7f)
66+
.build();
67+
6668
public final RetryTemplate retryTemplate = RetryTemplate.builder()
6769
.maxAttempts(10)
6870
.retryOn(OpenAiApiException.class)
@@ -76,40 +78,23 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
7678
this.openAiApi = openAiApi;
7779
}
7880

79-
public String getModel() {
80-
return this.model;
81-
}
82-
83-
public void setModel(String model) {
84-
this.model = model;
85-
}
86-
87-
public Double getTemperature() {
88-
return this.temperature;
89-
}
90-
91-
public void setTemperature(Double temperature) {
92-
this.temperature = temperature;
81+
public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
82+
this.defaultOptions = options;
83+
return this;
9384
}
9485

9586
@Override
9687
public ChatResponse call(Prompt prompt) {
9788

9889
return this.retryTemplate.execute(ctx -> {
99-
List<Message> messages = prompt.getInstructions();
10090

101-
List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
102-
.map(m -> new ChatCompletionMessage(m.getContent(),
103-
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
104-
.toList();
91+
ChatCompletionRequest request = createRequest(prompt, false);
10592

106-
ResponseEntity<ChatCompletion> completionEntity = this.openAiApi
107-
.chatCompletionEntity(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
108-
this.temperature.floatValue()));
93+
ResponseEntity<ChatCompletion> completionEntity = this.openAiApi.chatCompletionEntity(request);
10994

11095
var chatCompletion = completionEntity.getBody();
11196
if (chatCompletion == null) {
112-
logger.warn("No chat completion returned for request: {}", chatCompletionMessages);
97+
logger.warn("No chat completion returned for request: {}", prompt);
11398
return new ChatResponse(List.of());
11499
}
115100

@@ -128,16 +113,9 @@ public ChatResponse call(Prompt prompt) {
128113
@Override
129114
public Flux<ChatResponse> stream(Prompt prompt) {
130115
return this.retryTemplate.execute(ctx -> {
131-
List<Message> messages = prompt.getInstructions();
116+
ChatCompletionRequest request = createRequest(prompt, true);
132117

133-
List<ChatCompletionMessage> chatCompletionMessages = messages.stream()
134-
.map(m -> new ChatCompletionMessage(m.getContent(),
135-
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
136-
.toList();
137-
138-
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi
139-
.chatCompletionStream(new OpenAiApi.ChatCompletionRequest(chatCompletionMessages, this.model,
140-
this.temperature.floatValue(), true));
118+
Flux<OpenAiApi.ChatCompletionChunk> completionChunks = this.openAiApi.chatCompletionStream(request);
141119

142120
// For chunked responses, only the first chunk contains the choice role.
143121
// The rest of the chunks with same ID share the same role.
@@ -161,4 +139,34 @@ public Flux<ChatResponse> stream(Prompt prompt) {
161139
});
162140
}
163141

142+
/**
143+
* Accessible for testing.
144+
*/
145+
ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
146+
147+
List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
148+
.stream()
149+
.map(m -> new ChatCompletionMessage(m.getContent(),
150+
ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
151+
.toList();
152+
153+
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
154+
155+
if (this.defaultOptions != null) {
156+
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
157+
}
158+
159+
if (prompt.getOptions() != null) {
160+
if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
161+
request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class);
162+
}
163+
else {
164+
throw new IllegalArgumentException("Prompt options are not of type ChatCompletionRequest:"
165+
+ prompt.getOptions().getClass().getSimpleName());
166+
}
167+
}
168+
169+
return request;
170+
}
171+
164172
}

0 commit comments

Comments
 (0)