Skip to content

Commit e92616b

Browse files
ricken07tzolov
authored andcommitted
fix(mistral) Added index of tool call in the list of tool calls
Additiona fixes: API key validation and tool calling backward compatibility - Fix API key validation in OpenAiApi builder - Standardize API key validation using Assert.notNull - Add backward compatibility support for FunctionCallback in tool calling - Update integration tests to use LegacyToolCallingManager Co-authored-by:Christian Tzolov <christian.tzolov@broadcom.com> Signed-off-by: Ricken Bazolo <ricken.bazolo@gmail.com> Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 011deb0 commit e92616b

File tree

8 files changed

+48
-39
lines changed

8 files changed

+48
-39
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ else if (message instanceof AssistantMessage assistantMessage) {
384384
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
385385
toolCalls = assistantMessage.getToolCalls().stream().map(toolCall -> {
386386
var function = new ChatCompletionFunction(toolCall.name(), toolCall.arguments());
387-
return new ToolCall(toolCall.id(), toolCall.type(), function);
387+
return new ToolCall(toolCall.id(), toolCall.type(), function, null);
388388
}).toList();
389389
}
390390

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,10 +833,11 @@ public enum Role {
833833
* @param type The type of tool call the output is required for. For now, this is
834834
* always function.
835835
* @param function The function definition.
836+
* @param index The index of the tool call in the list of tool calls.
836837
*/
837838
@JsonInclude(Include.NON_NULL)
838839
public record ToolCall(@JsonProperty("id") String id, @JsonProperty("type") String type,
839-
@JsonProperty("function") ChatCompletionFunction function) {
840+
@JsonProperty("function") ChatCompletionFunction function, @JsonProperty("index") Integer index) {
840841

841842
}
842843

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiStreamFunctionCallingHelper.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.Objects;
2122
import java.util.Optional;
2223
import java.util.UUID;
2324

@@ -74,16 +75,16 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
7475
Optional<String> id = current.delta()
7576
.toolCalls()
7677
.stream()
77-
.filter(tool -> tool.id() != null)
78-
.map(tool -> tool.id())
78+
.map(ToolCall::id)
79+
.filter(Objects::nonNull)
7980
.findFirst();
80-
if (!id.isPresent()) {
81+
if (id.isEmpty()) {
8182
var newId = UUID.randomUUID().toString();
8283

8384
var toolCallsWithID = current.delta()
8485
.toolCalls()
8586
.stream()
86-
.map(toolCall -> new ToolCall(newId, "function", toolCall.function()))
87+
.map(toolCall -> new ToolCall(newId, "function", toolCall.function(), toolCall.index()))
8788
.toList();
8889

8990
var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT;
@@ -151,7 +152,8 @@ private ToolCall merge(ToolCall previous, ToolCall current) {
151152
String id = (current.id() != null ? current.id() : previous.id());
152153
String type = (current.type() != null ? current.type() : previous.type());
153154
ChatCompletionFunction function = merge(previous.function(), current.function());
154-
return new ToolCall(id, type, function);
155+
Integer index = (current.index() != null ? current.index() : previous.index());
156+
return new ToolCall(id, type, function, index);
155157
}
156158

157159
private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,6 +1675,12 @@ public Builder apiKey(ApiKey apiKey) {
16751675
return this;
16761676
}
16771677

1678+
public Builder apiKey(String simpleApiKey) {
1679+
Assert.notNull(simpleApiKey, "apiKey cannot be null");
1680+
this.apiKey = new SimpleApiKey(simpleApiKey);
1681+
return this;
1682+
}
1683+
16781684
public Builder headers(MultiValueMap<String, String> headers) {
16791685
Assert.notNull(headers, "headers cannot be null");
16801686
this.headers = headers;

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/MistralWithOpenAiChatModelIT.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@
4646
import org.springframework.ai.converter.ListOutputConverter;
4747
import org.springframework.ai.converter.MapOutputConverter;
4848
import org.springframework.ai.model.Media;
49+
import org.springframework.ai.model.SimpleApiKey;
4950
import org.springframework.ai.model.function.FunctionCallback;
51+
import org.springframework.ai.model.tool.LegacyToolCallingManager;
5052
import org.springframework.ai.openai.OpenAiChatModel;
5153
import org.springframework.ai.openai.OpenAiChatOptions;
5254
import org.springframework.ai.openai.api.OpenAiApi;
5355
import org.springframework.ai.openai.api.tool.MockWeatherService;
5456
import org.springframework.ai.openai.chat.ActorsFilms;
57+
import org.springframework.ai.tool.function.FunctionToolCallback;
58+
import org.springframework.ai.tool.method.MethodToolCallback;
5559
import org.springframework.beans.factory.annotation.Autowired;
5660
import org.springframework.beans.factory.annotation.Value;
5761
import org.springframework.boot.SpringBootConfiguration;
@@ -61,6 +65,7 @@
6165
import org.springframework.core.io.ClassPathResource;
6266
import org.springframework.core.io.Resource;
6367
import org.springframework.util.MimeTypeUtils;
68+
import org.springframework.util.ReflectionUtils;
6469

6570
import static org.assertj.core.api.Assertions.assertThat;
6671

@@ -251,8 +256,7 @@ void functionCallTest(String modelName) {
251256

252257
var promptOptions = OpenAiChatOptions.builder()
253258
.model(modelName)
254-
.functionCallbacks(List.of(FunctionCallback.builder()
255-
.function("getCurrentWeather", new MockWeatherService())
259+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
256260
.description("Get the weather in location")
257261
.inputType(MockWeatherService.Request.class)
258262
.build()))
@@ -276,8 +280,7 @@ void streamFunctionCallTest(String modelName) {
276280

277281
var promptOptions = OpenAiChatOptions.builder()
278282
.model(modelName)
279-
.functionCallbacks(List.of(FunctionCallback.builder()
280-
.function("getCurrentWeather", new MockWeatherService())
283+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
281284
.description("Get the weather in location")
282285
.inputType(MockWeatherService.Request.class)
283286
.build()))
@@ -388,12 +391,16 @@ static class Config {
388391

389392
@Bean
390393
public OpenAiApi chatCompletionApi() {
391-
return new OpenAiApi(MISTRAL_BASE_URL, System.getenv("MISTRAL_AI_API_KEY"));
394+
return OpenAiApi.builder().baseUrl(MISTRAL_BASE_URL).apiKey(System.getenv("MISTRAL_AI_API_KEY")).build();
392395
}
393396

394397
@Bean
395398
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
396-
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(MISTRAL_DEFAULT_MODEL).build());
399+
return OpenAiChatModel.builder()
400+
.openAiApi(openAiApi)
401+
.toolCallingManager(LegacyToolCallingManager.builder().build())
402+
.defaultOptions(OpenAiChatOptions.builder().model(MISTRAL_DEFAULT_MODEL).build())
403+
.build();
397404
}
398405

399406
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/OllamaWithOpenAiChatModelIT.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@
5050
import org.springframework.ai.converter.MapOutputConverter;
5151
import org.springframework.ai.model.Media;
5252
import org.springframework.ai.model.function.FunctionCallback;
53+
import org.springframework.ai.model.tool.LegacyToolCallingManager;
5354
import org.springframework.ai.openai.OpenAiChatModel;
5455
import org.springframework.ai.openai.OpenAiChatOptions;
5556
import org.springframework.ai.openai.api.OpenAiApi;
5657
import org.springframework.ai.openai.api.tool.MockWeatherService;
5758
import org.springframework.ai.openai.chat.ActorsFilms;
59+
import org.springframework.ai.tool.function.FunctionToolCallback;
5860
import org.springframework.beans.factory.annotation.Autowired;
5961
import org.springframework.beans.factory.annotation.Value;
6062
import org.springframework.boot.SpringBootConfiguration;
@@ -272,8 +274,7 @@ void functionCallTest(String modelName) {
272274
// Note for Ollama you must set the tool choice to explicitly. Unlike OpenAI
273275
// (which defaults to "auto") Ollama defaults to "nono"
274276
.toolChoice("auto")
275-
.functionCallbacks(List.of(FunctionCallback.builder()
276-
.function("getCurrentWeather", new MockWeatherService())
277+
.toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService())
277278
.description("Get the weather in location")
278279
.inputType(MockWeatherService.Request.class)
279280
.build()))
@@ -412,12 +413,16 @@ static class Config {
412413

413414
@Bean
414415
public OpenAiApi chatCompletionApi() {
415-
return new OpenAiApi(baseUrl, "");
416+
return OpenAiApi.builder().baseUrl(baseUrl).apiKey("").build();
416417
}
417418

418419
@Bean
419420
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
420-
return new OpenAiChatModel(openAiApi, OpenAiChatOptions.builder().model(DEFAULT_OLLAMA_MODEL).build());
421+
return OpenAiChatModel.builder()
422+
.openAiApi(openAiApi)
423+
.toolCallingManager(LegacyToolCallingManager.builder().build())
424+
.defaultOptions(OpenAiChatOptions.builder().model(DEFAULT_OLLAMA_MODEL).build())
425+
.build();
421426
}
422427

423428
}

spring-ai-core/src/main/java/org/springframework/ai/model/SimpleApiKey.java

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,46 +23,28 @@
2323
* be refreshed or rotated.
2424
*
2525
* @author Adib Saikali
26+
* @author Christian Tzolov
2627
* @since 1.0.0
2728
*/
28-
public final class SimpleApiKey implements ApiKey {
29-
30-
private final String value;
29+
public record SimpleApiKey(String value) implements ApiKey {
3130

3231
/**
3332
* Create a new SimpleApiKey.
3433
* @param value the API key value, must not be null or empty
3534
* @throws IllegalArgumentException if value is null or empty
3635
*/
3736
public SimpleApiKey(String value) {
38-
Assert.hasText(value, "API key value must not be null or empty");
37+
Assert.notNull(value, "API key value must not be null or empty");
3938
this.value = value;
4039
}
4140

4241
@Override
4342
public String getValue() {
44-
return this.value;
43+
return this.value();
4544
}
4645

4746
@Override
4847
public String toString() {
4948
return "SimpleApiKey{value='***'}";
5049
}
51-
52-
@Override
53-
public boolean equals(Object o) {
54-
if (this == o) {
55-
return true;
56-
}
57-
if (!(o instanceof SimpleApiKey that)) {
58-
return false;
59-
}
60-
return this.value.equals(that.value);
61-
}
62-
63-
@Override
64-
public int hashCode() {
65-
return this.value.hashCode();
66-
}
67-
6850
}

spring-ai-core/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,12 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions)
214214
else if (toolCallback instanceof ToolCallback callback) {
215215
returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
216216
}
217+
else if (returnDirect == null) {
218+
// This is a temporary solution to ensure backward compatibility with
219+
// FunctionCallback.
220+
// TODO: remove this block when FunctionCallback is removed.
221+
returnDirect = false;
222+
}
217223

218224
String toolResult;
219225
try {

0 commit comments

Comments
 (0)