Skip to content

Commit 076726c

Browse files
committed
Add Mistral AI Function Calling support
- Make MistralAiChatClient extend AbstractFunctionCallSupport and implement the necessary abstract methods. - Extend the MistralAiApi to include the latest (undocumented) changes that provide function calling support. The Mistral AI API is almost identical to the OpenAI API, except that it doesn't support parallel function calling (e.g. the tool_call_id is missing). - Add MistralAiApi function calling tests (implementing the Mistral tutorial). - Extend the Mistral chat options to include the new API features and function call abstractions. - Extend Mistral's chat auto-configuration to accommodate the function callback support. - Add integration tests (ITs) for function calling. - Remove redundant code from MistralAiApi and OpenAiApi. - Simplify and improve the HTTP error handling in OpenAiApi, ImageAiApi and MistralAiApi.
1 parent 5552a11 commit 076726c

File tree

24 files changed

+1564
-276
lines changed

24 files changed

+1564
-276
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 164 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
package org.springframework.ai.mistralai;
1717

1818
import java.time.Duration;
19+
import java.util.HashSet;
1920
import java.util.List;
2021
import java.util.Map;
22+
import java.util.Set;
2123
import java.util.concurrent.ConcurrentHashMap;
2224

2325
import org.slf4j.Logger;
@@ -32,18 +34,29 @@
3234
import org.springframework.ai.chat.prompt.ChatOptions;
3335
import org.springframework.ai.chat.prompt.Prompt;
3436
import org.springframework.ai.mistralai.api.MistralAiApi;
37+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
38+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
39+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
40+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
3541
import org.springframework.ai.model.ModelOptionsUtils;
42+
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
43+
import org.springframework.ai.model.function.FunctionCallbackContext;
44+
import org.springframework.http.ResponseEntity;
3645
import org.springframework.retry.RetryCallback;
3746
import org.springframework.retry.RetryContext;
3847
import org.springframework.retry.RetryListener;
3948
import org.springframework.retry.support.RetryTemplate;
4049
import org.springframework.util.Assert;
50+
import org.springframework.util.CollectionUtils;
4151

4252
/**
4353
* @author Ricken Bazolo
54+
* @author Christian Tzolov
4455
* @since 0.8.1
4556
*/
46-
public class MistralAiChatClient implements ChatClient, StreamingChatClient {
57+
public class MistralAiChatClient extends
58+
AbstractFunctionCallSupport<MistralAiApi.ChatCompletionMessage, MistralAiApi.ChatCompletionRequest, ResponseEntity<MistralAiApi.ChatCompletion>>
59+
implements ChatClient, StreamingChatClient {
4760

4861
private final Logger log = LoggerFactory.getLogger(getClass());
4962

@@ -69,13 +82,6 @@ public <T extends Object, E extends Throwable> void onError(RetryContext context
6982
})
7083
.build();
7184

72-
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
73-
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
74-
Assert.notNull(options, "Options must not be null");
75-
this.mistralAiApi = mistralAiApi;
76-
this.defaultOptions = options;
77-
}
78-
7985
public MistralAiChatClient(MistralAiApi mistralAiApi) {
8086
this(mistralAiApi,
8187
MistralAiChatOptions.builder()
@@ -86,10 +92,79 @@ public MistralAiChatClient(MistralAiApi mistralAiApi) {
8692
.build());
8793
}
8894

95+
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
96+
this(mistralAiApi, options, null);
97+
}
98+
99+
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options,
100+
FunctionCallbackContext functionCallbackContext) {
101+
super(functionCallbackContext);
102+
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
103+
Assert.notNull(options, "Options must not be null");
104+
this.mistralAiApi = mistralAiApi;
105+
this.defaultOptions = options;
106+
}
107+
108+
@Override
109+
public ChatResponse call(Prompt prompt) {
110+
// return retryTemplate.execute(ctx -> {
111+
var request = createRequest(prompt, false);
112+
113+
// var completionEntity = this.mistralAiApi.chatCompletionEntity(request);
114+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
115+
116+
var chatCompletion = completionEntity.getBody();
117+
if (chatCompletion == null) {
118+
log.warn("No chat completion returned for prompt: {}", prompt);
119+
return new ChatResponse(List.of());
120+
}
121+
122+
List<Generation> generations = chatCompletion.choices()
123+
.stream()
124+
.map(choice -> new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
125+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
126+
.toList();
127+
128+
return new ChatResponse(generations);
129+
// });
130+
}
131+
132+
@Override
133+
public Flux<ChatResponse> stream(Prompt prompt) {
134+
return retryTemplate.execute(ctx -> {
135+
var request = createRequest(prompt, true);
136+
137+
var completionChunks = this.mistralAiApi.chatCompletionStream(request);
138+
139+
// For chunked responses, only the first chunk contains the choice role.
140+
// The rest of the chunks with same ID share the same role.
141+
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
142+
143+
return completionChunks.map(chunk -> {
144+
String chunkId = chunk.id();
145+
List<Generation> generations = chunk.choices().stream().map(choice -> {
146+
if (choice.delta().role() != null) {
147+
roleMap.putIfAbsent(chunkId, choice.delta().role().name());
148+
}
149+
var generation = new Generation(choice.delta().content(), Map.of("role", roleMap.get(chunkId)));
150+
if (choice.finishReason() != null) {
151+
generation = generation
152+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
153+
}
154+
return generation;
155+
}).toList();
156+
return new ChatResponse(generations);
157+
});
158+
});
159+
}
160+
89161
/**
90162
* Accessible for testing.
91163
*/
92-
public MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
164+
MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
165+
166+
Set<String> functionsForThisRequest = new HashSet<>();
167+
93168
var chatCompletionMessages = prompt.getInstructions()
94169
.stream()
95170
.map(m -> new MistralAiApi.ChatCompletionMessage(m.getContent(),
@@ -99,13 +174,23 @@ public MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean s
99174
var request = new MistralAiApi.ChatCompletionRequest(chatCompletionMessages, stream);
100175

101176
if (this.defaultOptions != null) {
177+
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions,
178+
!IS_RUNTIME_CALL);
179+
180+
functionsForThisRequest.addAll(defaultEnabledFunctions);
181+
102182
request = ModelOptionsUtils.merge(request, this.defaultOptions, MistralAiApi.ChatCompletionRequest.class);
103183
}
104184

105185
if (prompt.getOptions() != null) {
106186
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
107187
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class,
108188
MistralAiChatOptions.class);
189+
190+
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
191+
IS_RUNTIME_CALL);
192+
functionsForThisRequest.addAll(promptEnabledFunctions);
193+
109194
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request,
110195
MistralAiApi.ChatCompletionRequest.class);
111196
}
@@ -115,60 +200,91 @@ public MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean s
115200
}
116201
}
117202

203+
// Add the enabled functions definitions to the request's tools parameter.
204+
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
205+
206+
if (stream) {
207+
throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode");
208+
}
209+
210+
request = ModelOptionsUtils.merge(
211+
MistralAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(),
212+
request, ChatCompletionRequest.class);
213+
}
214+
118215
return request;
119216
}
120217

218+
private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
219+
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
220+
var function = new MistralAiApi.FunctionTool.Function(functionCallback.getDescription(),
221+
functionCallback.getName(), functionCallback.getInputTypeSchema());
222+
return new MistralAiApi.FunctionTool(function);
223+
}).toList();
224+
}
225+
226+
//
227+
// Function Calling Support
228+
//
121229
@Override
122-
public ChatResponse call(Prompt prompt) {
123-
return retryTemplate.execute(ctx -> {
124-
var request = createRequest(prompt, false);
230+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
231+
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
125232

126-
var completionEntity = this.mistralAiApi.chatCompletionEntity(request);
233+
// Every tool-call item requires a separate function call and a response (TOOL)
234+
// message.
235+
for (ToolCall toolCall : responseMessage.toolCalls()) {
127236

128-
var chatCompletion = completionEntity.getBody();
129-
if (chatCompletion == null) {
130-
log.warn("No chat completion returned for prompt: {}", prompt);
131-
return new ChatResponse(List.of());
237+
var functionName = toolCall.function().name();
238+
String functionArguments = toolCall.function().arguments();
239+
240+
if (!this.functionCallbackRegister.containsKey(functionName)) {
241+
throw new IllegalStateException("No function callback found for function name: " + functionName);
132242
}
133243

134-
List<Generation> generations = chatCompletion.choices()
135-
.stream()
136-
.map(choice -> new Generation(choice.message().content(),
137-
Map.of("role", choice.message().role().name()))
138-
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
139-
.toList();
244+
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
140245

141-
return new ChatResponse(generations);
142-
});
246+
// Add the function response to the conversation.
247+
conversationHistory
248+
.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null));
249+
}
250+
251+
// Recursively call chatCompletionWithTools until the model doesn't call any
252+
// functions anymore.
253+
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, previousRequest.stream());
254+
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
255+
256+
return newRequest;
143257
}
144258

145259
@Override
146-
public Flux<ChatResponse> stream(Prompt prompt) {
147-
return retryTemplate.execute(ctx -> {
148-
var request = createRequest(prompt, true);
260+
protected List<ChatCompletionMessage> doGetUserMessages(ChatCompletionRequest request) {
261+
return request.messages();
262+
}
149263

150-
var completionChunks = this.mistralAiApi.chatCompletionStream(request);
264+
@Override
265+
protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<ChatCompletion> chatCompletion) {
266+
return chatCompletion.getBody().choices().iterator().next().message();
267+
}
151268

152-
// For chunked responses, only the first chunk contains the choice role.
153-
// The rest of the chunks with same ID share the same role.
154-
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
269+
@Override
270+
protected ResponseEntity<ChatCompletion> doChatCompletion(ChatCompletionRequest request) {
271+
return this.mistralAiApi.chatCompletionEntity(request);
272+
}
155273

156-
return completionChunks.map(chunk -> {
157-
String chunkId = chunk.id();
158-
List<Generation> generations = chunk.choices().stream().map(choice -> {
159-
if (choice.delta().role() != null) {
160-
roleMap.putIfAbsent(chunkId, choice.delta().role().name());
161-
}
162-
var generation = new Generation(choice.delta().content(), Map.of("role", roleMap.get(chunkId)));
163-
if (choice.finishReason() != null) {
164-
generation = generation
165-
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
166-
}
167-
return generation;
168-
}).toList();
169-
return new ChatResponse(generations);
170-
});
171-
});
274+
@Override
275+
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> chatCompletion) {
276+
277+
var body = chatCompletion.getBody();
278+
if (body == null) {
279+
return false;
280+
}
281+
282+
var choices = body.choices();
283+
if (CollectionUtils.isEmpty(choices)) {
284+
return false;
285+
}
286+
287+
return !CollectionUtils.isEmpty(choices.get(0).message().toolCalls());
172288
}
173289

174290
}

0 commit comments

Comments
 (0)