16
16
package org .springframework .ai .mistralai ;
17
17
18
18
import java .time .Duration ;
19
+ import java .util .HashSet ;
19
20
import java .util .List ;
20
21
import java .util .Map ;
22
+ import java .util .Set ;
21
23
import java .util .concurrent .ConcurrentHashMap ;
22
24
23
25
import org .slf4j .Logger ;
32
34
import org .springframework .ai .chat .prompt .ChatOptions ;
33
35
import org .springframework .ai .chat .prompt .Prompt ;
34
36
import org .springframework .ai .mistralai .api .MistralAiApi ;
37
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletion ;
38
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionMessage ;
39
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionMessage .ToolCall ;
40
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionRequest ;
35
41
import org .springframework .ai .model .ModelOptionsUtils ;
42
+ import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
43
+ import org .springframework .ai .model .function .FunctionCallbackContext ;
44
+ import org .springframework .http .ResponseEntity ;
36
45
import org .springframework .retry .RetryCallback ;
37
46
import org .springframework .retry .RetryContext ;
38
47
import org .springframework .retry .RetryListener ;
39
48
import org .springframework .retry .support .RetryTemplate ;
40
49
import org .springframework .util .Assert ;
50
+ import org .springframework .util .CollectionUtils ;
41
51
42
52
/**
43
53
* @author Ricken Bazolo
54
+ * @author Christian Tzolov
44
55
* @since 0.8.1
45
56
*/
46
- public class MistralAiChatClient implements ChatClient , StreamingChatClient {
57
+ public class MistralAiChatClient extends
58
+ AbstractFunctionCallSupport <MistralAiApi .ChatCompletionMessage , MistralAiApi .ChatCompletionRequest , ResponseEntity <MistralAiApi .ChatCompletion >>
59
+ implements ChatClient , StreamingChatClient {
47
60
48
61
private final Logger log = LoggerFactory .getLogger (getClass ());
49
62
@@ -69,13 +82,6 @@ public <T extends Object, E extends Throwable> void onError(RetryContext context
69
82
})
70
83
.build ();
71
84
72
- public MistralAiChatClient (MistralAiApi mistralAiApi , MistralAiChatOptions options ) {
73
- Assert .notNull (mistralAiApi , "MistralAiApi must not be null" );
74
- Assert .notNull (options , "Options must not be null" );
75
- this .mistralAiApi = mistralAiApi ;
76
- this .defaultOptions = options ;
77
- }
78
-
79
85
public MistralAiChatClient (MistralAiApi mistralAiApi ) {
80
86
this (mistralAiApi ,
81
87
MistralAiChatOptions .builder ()
@@ -86,10 +92,79 @@ public MistralAiChatClient(MistralAiApi mistralAiApi) {
86
92
.build ());
87
93
}
88
94
95
+ public MistralAiChatClient (MistralAiApi mistralAiApi , MistralAiChatOptions options ) {
96
+ this (mistralAiApi , options , null );
97
+ }
98
+
99
+ public MistralAiChatClient (MistralAiApi mistralAiApi , MistralAiChatOptions options ,
100
+ FunctionCallbackContext functionCallbackContext ) {
101
+ super (functionCallbackContext );
102
+ Assert .notNull (mistralAiApi , "MistralAiApi must not be null" );
103
+ Assert .notNull (options , "Options must not be null" );
104
+ this .mistralAiApi = mistralAiApi ;
105
+ this .defaultOptions = options ;
106
+ }
107
+
108
+ @ Override
109
+ public ChatResponse call (Prompt prompt ) {
110
+ // return retryTemplate.execute(ctx -> {
111
+ var request = createRequest (prompt , false );
112
+
113
+ // var completionEntity = this.mistralAiApi.chatCompletionEntity(request);
114
+ ResponseEntity <ChatCompletion > completionEntity = this .callWithFunctionSupport (request );
115
+
116
+ var chatCompletion = completionEntity .getBody ();
117
+ if (chatCompletion == null ) {
118
+ log .warn ("No chat completion returned for prompt: {}" , prompt );
119
+ return new ChatResponse (List .of ());
120
+ }
121
+
122
+ List <Generation > generations = chatCompletion .choices ()
123
+ .stream ()
124
+ .map (choice -> new Generation (choice .message ().content (), Map .of ("role" , choice .message ().role ().name ()))
125
+ .withGenerationMetadata (ChatGenerationMetadata .from (choice .finishReason ().name (), null )))
126
+ .toList ();
127
+
128
+ return new ChatResponse (generations );
129
+ // });
130
+ }
131
+
132
+ @ Override
133
+ public Flux <ChatResponse > stream (Prompt prompt ) {
134
+ return retryTemplate .execute (ctx -> {
135
+ var request = createRequest (prompt , true );
136
+
137
+ var completionChunks = this .mistralAiApi .chatCompletionStream (request );
138
+
139
+ // For chunked responses, only the first chunk contains the choice role.
140
+ // The rest of the chunks with same ID share the same role.
141
+ ConcurrentHashMap <String , String > roleMap = new ConcurrentHashMap <>();
142
+
143
+ return completionChunks .map (chunk -> {
144
+ String chunkId = chunk .id ();
145
+ List <Generation > generations = chunk .choices ().stream ().map (choice -> {
146
+ if (choice .delta ().role () != null ) {
147
+ roleMap .putIfAbsent (chunkId , choice .delta ().role ().name ());
148
+ }
149
+ var generation = new Generation (choice .delta ().content (), Map .of ("role" , roleMap .get (chunkId )));
150
+ if (choice .finishReason () != null ) {
151
+ generation = generation
152
+ .withGenerationMetadata (ChatGenerationMetadata .from (choice .finishReason ().name (), null ));
153
+ }
154
+ return generation ;
155
+ }).toList ();
156
+ return new ChatResponse (generations );
157
+ });
158
+ });
159
+ }
160
+
89
161
/**
90
162
* Accessible for testing.
91
163
*/
92
- public MistralAiApi .ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
164
+ MistralAiApi .ChatCompletionRequest createRequest (Prompt prompt , boolean stream ) {
165
+
166
+ Set <String > functionsForThisRequest = new HashSet <>();
167
+
93
168
var chatCompletionMessages = prompt .getInstructions ()
94
169
.stream ()
95
170
.map (m -> new MistralAiApi .ChatCompletionMessage (m .getContent (),
@@ -99,13 +174,23 @@ public MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean s
99
174
var request = new MistralAiApi .ChatCompletionRequest (chatCompletionMessages , stream );
100
175
101
176
if (this .defaultOptions != null ) {
177
+ Set <String > defaultEnabledFunctions = this .handleFunctionCallbackConfigurations (this .defaultOptions ,
178
+ !IS_RUNTIME_CALL );
179
+
180
+ functionsForThisRequest .addAll (defaultEnabledFunctions );
181
+
102
182
request = ModelOptionsUtils .merge (request , this .defaultOptions , MistralAiApi .ChatCompletionRequest .class );
103
183
}
104
184
105
185
if (prompt .getOptions () != null ) {
106
186
if (prompt .getOptions () instanceof ChatOptions runtimeOptions ) {
107
187
var updatedRuntimeOptions = ModelOptionsUtils .copyToTarget (runtimeOptions , ChatOptions .class ,
108
188
MistralAiChatOptions .class );
189
+
190
+ Set <String > promptEnabledFunctions = this .handleFunctionCallbackConfigurations (updatedRuntimeOptions ,
191
+ IS_RUNTIME_CALL );
192
+ functionsForThisRequest .addAll (promptEnabledFunctions );
193
+
109
194
request = ModelOptionsUtils .merge (updatedRuntimeOptions , request ,
110
195
MistralAiApi .ChatCompletionRequest .class );
111
196
}
@@ -115,60 +200,91 @@ public MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean s
115
200
}
116
201
}
117
202
203
+ // Add the enabled functions definitions to the request's tools parameter.
204
+ if (!CollectionUtils .isEmpty (functionsForThisRequest )) {
205
+
206
+ if (stream ) {
207
+ throw new IllegalArgumentException ("Currently tool functions are not supported in streaming mode" );
208
+ }
209
+
210
+ request = ModelOptionsUtils .merge (
211
+ MistralAiChatOptions .builder ().withTools (this .getFunctionTools (functionsForThisRequest )).build (),
212
+ request , ChatCompletionRequest .class );
213
+ }
214
+
118
215
return request ;
119
216
}
120
217
218
+ private List <MistralAiApi .FunctionTool > getFunctionTools (Set <String > functionNames ) {
219
+ return this .resolveFunctionCallbacks (functionNames ).stream ().map (functionCallback -> {
220
+ var function = new MistralAiApi .FunctionTool .Function (functionCallback .getDescription (),
221
+ functionCallback .getName (), functionCallback .getInputTypeSchema ());
222
+ return new MistralAiApi .FunctionTool (function );
223
+ }).toList ();
224
+ }
225
+
226
+ //
227
+ // Function Calling Support
228
+ //
121
229
@ Override
122
- public ChatResponse call (Prompt prompt ) {
123
- return retryTemplate .execute (ctx -> {
124
- var request = createRequest (prompt , false );
230
+ protected ChatCompletionRequest doCreateToolResponseRequest (ChatCompletionRequest previousRequest ,
231
+ ChatCompletionMessage responseMessage , List <ChatCompletionMessage > conversationHistory ) {
125
232
126
- var completionEntity = this .mistralAiApi .chatCompletionEntity (request );
233
+ // Every tool-call item requires a separate function call and a response (TOOL)
234
+ // message.
235
+ for (ToolCall toolCall : responseMessage .toolCalls ()) {
127
236
128
- var chatCompletion = completionEntity .getBody ();
129
- if (chatCompletion == null ) {
130
- log .warn ("No chat completion returned for prompt: {}" , prompt );
131
- return new ChatResponse (List .of ());
237
+ var functionName = toolCall .function ().name ();
238
+ String functionArguments = toolCall .function ().arguments ();
239
+
240
+ if (!this .functionCallbackRegister .containsKey (functionName )) {
241
+ throw new IllegalStateException ("No function callback found for function name: " + functionName );
132
242
}
133
243
134
- List <Generation > generations = chatCompletion .choices ()
135
- .stream ()
136
- .map (choice -> new Generation (choice .message ().content (),
137
- Map .of ("role" , choice .message ().role ().name ()))
138
- .withGenerationMetadata (ChatGenerationMetadata .from (choice .finishReason ().name (), null )))
139
- .toList ();
244
+ String functionResponse = this .functionCallbackRegister .get (functionName ).call (functionArguments );
140
245
141
- return new ChatResponse (generations );
142
- });
246
+ // Add the function response to the conversation.
247
+ conversationHistory
248
+ .add (new ChatCompletionMessage (functionResponse , ChatCompletionMessage .Role .TOOL , functionName , null ));
249
+ }
250
+
251
+ // Recursively call chatCompletionWithTools until the model doesn't call a
252
+ // functions anymore.
253
+ ChatCompletionRequest newRequest = new ChatCompletionRequest (conversationHistory , previousRequest .stream ());
254
+ newRequest = ModelOptionsUtils .merge (newRequest , previousRequest , ChatCompletionRequest .class );
255
+
256
+ return newRequest ;
143
257
}
144
258
145
259
@ Override
146
- public Flux < ChatResponse > stream ( Prompt prompt ) {
147
- return retryTemplate . execute ( ctx -> {
148
- var request = createRequest ( prompt , true );
260
+ protected List < ChatCompletionMessage > doGetUserMessages ( ChatCompletionRequest request ) {
261
+ return request . messages ();
262
+ }
149
263
150
- var completionChunks = this .mistralAiApi .chatCompletionStream (request );
264
+ @ Override
265
+ protected ChatCompletionMessage doGetToolResponseMessage (ResponseEntity <ChatCompletion > chatCompletion ) {
266
+ return chatCompletion .getBody ().choices ().iterator ().next ().message ();
267
+ }
151
268
152
- // For chunked responses, only the first chunk contains the choice role.
153
- // The rest of the chunks with same ID share the same role.
154
- ConcurrentHashMap <String , String > roleMap = new ConcurrentHashMap <>();
269
+ @ Override
270
+ protected ResponseEntity <ChatCompletion > doChatCompletion (ChatCompletionRequest request ) {
271
+ return this .mistralAiApi .chatCompletionEntity (request );
272
+ }
155
273
156
- return completionChunks .map (chunk -> {
157
- String chunkId = chunk .id ();
158
- List <Generation > generations = chunk .choices ().stream ().map (choice -> {
159
- if (choice .delta ().role () != null ) {
160
- roleMap .putIfAbsent (chunkId , choice .delta ().role ().name ());
161
- }
162
- var generation = new Generation (choice .delta ().content (), Map .of ("role" , roleMap .get (chunkId )));
163
- if (choice .finishReason () != null ) {
164
- generation = generation
165
- .withGenerationMetadata (ChatGenerationMetadata .from (choice .finishReason ().name (), null ));
166
- }
167
- return generation ;
168
- }).toList ();
169
- return new ChatResponse (generations );
170
- });
171
- });
274
+ @ Override
275
+ protected boolean isToolFunctionCall (ResponseEntity <ChatCompletion > chatCompletion ) {
276
+
277
+ var body = chatCompletion .getBody ();
278
+ if (body == null ) {
279
+ return false ;
280
+ }
281
+
282
+ var choices = body .choices ();
283
+ if (CollectionUtils .isEmpty (choices )) {
284
+ return false ;
285
+ }
286
+
287
+ return !CollectionUtils .isEmpty (choices .get (0 ).message ().toolCalls ());
172
288
}
173
289
174
290
}
0 commit comments