28
28
import org .springframework .ai .chat .ChatResponse ;
29
29
import org .springframework .ai .chat .Generation ;
30
30
import org .springframework .ai .chat .StreamingChatClient ;
31
- import org .springframework .ai .chat .messages .Message ;
32
31
import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
33
32
import org .springframework .ai .chat .metadata .RateLimit ;
34
33
import org .springframework .ai .chat .prompt .Prompt ;
34
+ import org .springframework .ai .model .ModelOptionsUtils ;
35
35
import org .springframework .ai .openai .api .OpenAiApi ;
36
36
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletion ;
37
37
import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionMessage ;
38
+ import org .springframework .ai .openai .api .OpenAiApi .ChatCompletionRequest ;
38
39
import org .springframework .ai .openai .api .OpenAiApi .OpenAiApiException ;
39
40
import org .springframework .ai .openai .metadata .OpenAiChatResponseMetadata ;
40
41
import org .springframework .ai .openai .metadata .support .OpenAiResponseHeaderExtractor ;
57
58
*/
58
59
public class OpenAiChatClient implements ChatClient , StreamingChatClient {
59
60
60
- private Double temperature = 0.7 ;
61
-
62
- private String model = "gpt-3.5-turbo" ;
63
-
64
61
	// SLF4J logger bound to the runtime class so subclasses log under their own name.
	private final Logger logger = LoggerFactory.getLogger(getClass());

	// Client-level default options merged into every request unless overridden
	// per prompt; preserves the previously hard-coded defaults
	// (model "gpt-3.5-turbo", temperature 0.7).
	private OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder()
		.withModel("gpt-3.5-turbo")
		.withTemperature(0.7f)
		.build();
66
68
public final RetryTemplate retryTemplate = RetryTemplate .builder ()
67
69
.maxAttempts (10 )
68
70
.retryOn (OpenAiApiException .class )
@@ -76,40 +78,23 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
76
78
this .openAiApi = openAiApi ;
77
79
}
78
80
79
	/**
	 * Replaces the client-level default {@link OpenAiChatOptions} that are merged
	 * into every chat completion request. Passing {@code null} disables
	 * default-option merging (the null case is tolerated by request creation).
	 * @param options the new defaults, applied to subsequent call/stream invocations
	 * @return this client, for fluent configuration
	 */
	public OpenAiChatClient withDefaultOptions(OpenAiChatOptions options) {
		this.defaultOptions = options;
		return this;
	}
94
85
95
86
@ Override
96
87
public ChatResponse call (Prompt prompt ) {
97
88
98
89
return this .retryTemplate .execute (ctx -> {
99
- List <Message > messages = prompt .getInstructions ();
100
90
101
- List <ChatCompletionMessage > chatCompletionMessages = messages .stream ()
102
- .map (m -> new ChatCompletionMessage (m .getContent (),
103
- ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ())))
104
- .toList ();
91
+ ChatCompletionRequest request = createRequest (prompt , false );
105
92
106
- ResponseEntity <ChatCompletion > completionEntity = this .openAiApi
107
- .chatCompletionEntity (new OpenAiApi .ChatCompletionRequest (chatCompletionMessages , this .model ,
108
- this .temperature .floatValue ()));
93
+ ResponseEntity <ChatCompletion > completionEntity = this .openAiApi .chatCompletionEntity (request );
109
94
110
95
var chatCompletion = completionEntity .getBody ();
111
96
if (chatCompletion == null ) {
112
- logger .warn ("No chat completion returned for request: {}" , chatCompletionMessages );
97
+ logger .warn ("No chat completion returned for request: {}" , prompt );
113
98
return new ChatResponse (List .of ());
114
99
}
115
100
@@ -128,16 +113,9 @@ public ChatResponse call(Prompt prompt) {
128
113
@ Override
129
114
public Flux <ChatResponse > stream (Prompt prompt ) {
130
115
return this .retryTemplate .execute (ctx -> {
131
- List < Message > messages = prompt . getInstructions ( );
116
+ ChatCompletionRequest request = createRequest ( prompt , true );
132
117
133
- List <ChatCompletionMessage > chatCompletionMessages = messages .stream ()
134
- .map (m -> new ChatCompletionMessage (m .getContent (),
135
- ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ())))
136
- .toList ();
137
-
138
- Flux <OpenAiApi .ChatCompletionChunk > completionChunks = this .openAiApi
139
- .chatCompletionStream (new OpenAiApi .ChatCompletionRequest (chatCompletionMessages , this .model ,
140
- this .temperature .floatValue (), true ));
118
+ Flux <OpenAiApi .ChatCompletionChunk > completionChunks = this .openAiApi .chatCompletionStream (request );
141
119
142
120
// For chunked responses, only the first chunk contains the choice role.
143
121
// The rest of the chunks with same ID share the same role.
@@ -161,4 +139,34 @@ public Flux<ChatResponse> stream(Prompt prompt) {
161
139
});
162
140
}
163
141
142
+ /**
143
+ * Accessible for testing.
144
+ */
145
+ ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
146
+
147
+ List <ChatCompletionMessage > chatCompletionMessages = prompt .getInstructions ()
148
+ .stream ()
149
+ .map (m -> new ChatCompletionMessage (m .getContent (),
150
+ ChatCompletionMessage .Role .valueOf (m .getMessageType ().name ())))
151
+ .toList ();
152
+
153
+ ChatCompletionRequest request = new ChatCompletionRequest (chatCompletionMessages , stream );
154
+
155
+ if (this .defaultOptions != null ) {
156
+ request = ModelOptionsUtils .merge (request , this .defaultOptions , ChatCompletionRequest .class );
157
+ }
158
+
159
+ if (prompt .getOptions () != null ) {
160
+ if (prompt .getOptions () instanceof OpenAiChatOptions runtimeOptions ) {
161
+ request = ModelOptionsUtils .merge (runtimeOptions , request , ChatCompletionRequest .class );
162
+ }
163
+ else {
164
+ throw new IllegalArgumentException ("Prompt options are not of type ChatCompletionRequest:"
165
+ + prompt .getOptions ().getClass ().getSimpleName ());
166
+ }
167
+ }
168
+
169
+ return request ;
170
+ }
171
+
164
172
}
0 commit comments