Skip to content

Commit 97537e0

Browse files
committed
Improve Ollama Options support
- Rename OllamaChatClient#withOptions(...) method to OllamaChatClient#withDefaultOptions(...)
- Rename OllamaEmbeddingClient#withOptions(...) method to OllamaEmbeddingClient#withDefaultOptions(...)
- Replace the Chat/Embedding Client model field with the defaultOptions.model one.
- Implement correct merging of the default and runtime OllamaOptions.
- Add support for portable ChatOptions.
- OllamaOptions has a synthetic field 'model' that is not supported by the Ollama API but is used by the OllamaChatClient and OllamaEmbeddingClient. The model field is removed before calling the OllamaApi.
- Update the IT ollama docker image to 0.1.23
- Set Mistral as the default model.
- Extend and improve the ITs
- Add tests for testing the chat and embedding request creation and options merging logic.
- Minor code-style improvements.
- Split the ollama.adoc into embeddings/ollama-embeddings.adoc and clients/ollama-chat.adoc.
- Improve the documentation to explain how to configure and use the Ollama Chat and Embedding clients manually or with the help of the auto-configurations.
- Clarify the docs property sections.
1 parent 60a60b4 commit 97537e0

File tree

22 files changed

+700
-199
lines changed

22 files changed

+700
-199
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatClient.java

Lines changed: 50 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,33 +16,28 @@
1616

1717
package org.springframework.ai.ollama;
1818

19-
import java.util.HashMap;
2019
import java.util.List;
21-
import java.util.Map;
22-
import java.util.stream.Collectors;
2320

24-
import com.fasterxml.jackson.core.JsonProcessingException;
25-
import com.fasterxml.jackson.core.type.TypeReference;
26-
import com.fasterxml.jackson.databind.ObjectMapper;
27-
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2821
import reactor.core.publisher.Flux;
2922

3023
import org.springframework.ai.chat.ChatClient;
24+
import org.springframework.ai.chat.ChatOptions;
3125
import org.springframework.ai.chat.ChatResponse;
3226
import org.springframework.ai.chat.Generation;
3327
import org.springframework.ai.chat.StreamingChatClient;
28+
import org.springframework.ai.chat.messages.Message;
29+
import org.springframework.ai.chat.messages.MessageType;
30+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3431
import org.springframework.ai.chat.metadata.Usage;
32+
import org.springframework.ai.chat.prompt.Prompt;
33+
import org.springframework.ai.model.ModelOptionsUtils;
3534
import org.springframework.ai.ollama.api.OllamaApi;
36-
import org.springframework.ai.ollama.api.OllamaOptions;
37-
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
3835
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
39-
40-
import org.springframework.ai.chat.prompt.Prompt;
41-
import org.springframework.ai.chat.messages.Message;
42-
import org.springframework.ai.chat.messages.MessageType;
36+
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.util.StringUtils;
4338

4439
/**
45-
* {@link ChatClient} implementation for {@literal Ollma}.
40+
* {@link ChatClient} implementation for {@literal Ollama}.
4641
*
4742
* Ollama allows developers to run large language models and generate embeddings locally.
4843
* It supports open-source models available on [Ollama AI
@@ -57,37 +52,38 @@
5752
*/
5853
public class OllamaChatClient implements ChatClient, StreamingChatClient {
5954

55+
/**
56+
* Low-level Ollama API library.
57+
*/
6058
private final OllamaApi chatApi;
6159

62-
private String model = "orca-mini";
63-
64-
private Map<String, Object> clientOptions;
65-
66-
private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper();
60+
/**
61+
* Default options to be used for all chat requests.
62+
*/
63+
private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
6764

6865
public OllamaChatClient(OllamaApi chatApi) {
6966
this.chatApi = chatApi;
7067
}
7168

69+
/**
70+
* @deprecated Use {@link OllamaOptions#setModel} instead.
71+
*/
72+
@Deprecated
7273
public OllamaChatClient withModel(String model) {
73-
this.model = model;
74+
this.defaultOptions.setModel(model);
7475
return this;
7576
}
7677

77-
public OllamaChatClient withOptions(Map<String, Object> options) {
78-
this.clientOptions = options;
79-
return this;
80-
}
81-
82-
public OllamaChatClient withOptions(OllamaOptions options) {
83-
this.clientOptions = options.toMap();
78+
public OllamaChatClient withDefaultOptions(OllamaOptions options) {
79+
this.defaultOptions = options;
8480
return this;
8581
}
8682

8783
@Override
8884
public ChatResponse call(Prompt prompt) {
8985

90-
OllamaApi.ChatResponse response = this.chatApi.chat(request(prompt, this.model, false));
86+
OllamaApi.ChatResponse response = this.chatApi.chat(ollamaChatRequest(prompt, false));
9187
var generator = new Generation(response.message().content());
9288
if (response.promptEvalCount() != null && response.evalCount() != null) {
9389
generator = generator
@@ -99,7 +95,7 @@ public ChatResponse call(Prompt prompt) {
9995
@Override
10096
public Flux<ChatResponse> stream(Prompt prompt) {
10197

102-
Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(request(prompt, this.model, true));
98+
Flux<OllamaApi.ChatResponse> response = this.chatApi.streamingChat(ollamaChatRequest(prompt, true));
10399

104100
return response.map(chunk -> {
105101
Generation generation = (chunk.message() != null) ? new Generation(chunk.message().content())
@@ -127,7 +123,10 @@ public Long getGenerationTokens() {
127123
};
128124
}
129125

130-
private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean stream) {
126+
/**
127+
* Package access for testing.
128+
*/
129+
OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
131130

132131
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions()
133132
.stream()
@@ -138,49 +137,31 @@ private OllamaApi.ChatRequest request(Prompt prompt, String model, boolean strea
138137
.toList();
139138

140139
// runtime options
141-
Map<String, Object> clientOptionsToUse = merge(prompt.getOptions(), this.clientOptions, HashMap.class);
142-
143-
return ChatRequest.builder(model)
144-
.withStream(stream)
145-
.withMessages(ollamaMessages)
146-
.withOptions(clientOptionsToUse)
147-
.build();
148-
}
149-
150-
public static Map<String, Object> objectToMap(Object source) {
151-
try {
152-
String json = OBJECT_MAPPER.writeValueAsString(source);
153-
return OBJECT_MAPPER.readValue(json, new TypeReference<Map<String, Object>>() {
154-
});
155-
}
156-
catch (JsonProcessingException e) {
157-
throw new RuntimeException(e);
140+
OllamaOptions runtimeOptions = null;
141+
if (prompt.getOptions() != null) {
142+
if (prompt.getOptions() instanceof ChatOptions runtimeChatOptions) {
143+
runtimeOptions = ModelOptionsUtils.copyToTarget(runtimeChatOptions, ChatOptions.class,
144+
OllamaOptions.class);
145+
}
146+
else {
147+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
148+
+ prompt.getOptions().getClass().getSimpleName());
149+
}
158150
}
159-
}
160151

161-
public static <T> T mapToClass(Map<String, Object> source, Class<T> clazz) {
162-
try {
163-
String json = OBJECT_MAPPER.writeValueAsString(source);
164-
return OBJECT_MAPPER.readValue(json, clazz);
165-
}
166-
catch (JsonProcessingException e) {
167-
throw new RuntimeException(e);
168-
}
169-
}
152+
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
170153

171-
public static <T> T merge(Object source, Object target, Class<T> clazz) {
172-
if (source == null) {
173-
source = Map.of();
154+
// Ensure that a model is set after merging the options.
155+
if (!StringUtils.hasText(mergedOptions.getModel())) {
156+
throw new IllegalArgumentException("Model is not set!");
174157
}
175-
Map<String, Object> sourceMap = objectToMap(source);
176-
Map<String, Object> targetMap = objectToMap(target);
177-
178-
targetMap.putAll(sourceMap.entrySet()
179-
.stream()
180-
.filter(e -> e.getValue() != null)
181-
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())));
182158

183-
return mapToClass(targetMap, clazz);
159+
String model = mergedOptions.getModel();
160+
return OllamaApi.ChatRequest.builder(model)
161+
.withStream(stream)
162+
.withMessages(ollamaMessages)
163+
.withOptions(mergedOptions)
164+
.build();
184165
}
185166

186167
private OllamaApi.Message.Role toRole(Message message) {

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingClient.java

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,23 +18,26 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21-
import java.util.Map;
2221
import java.util.concurrent.atomic.AtomicInteger;
2322

2423
import org.slf4j.Logger;
2524
import org.slf4j.LoggerFactory;
25+
2626
import org.springframework.ai.document.Document;
2727
import org.springframework.ai.embedding.AbstractEmbeddingClient;
2828
import org.springframework.ai.embedding.Embedding;
2929
import org.springframework.ai.embedding.EmbeddingClient;
30+
import org.springframework.ai.embedding.EmbeddingOptions;
3031
import org.springframework.ai.embedding.EmbeddingResponse;
32+
import org.springframework.ai.model.ModelOptionsUtils;
3133
import org.springframework.ai.ollama.api.OllamaApi;
3234
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingRequest;
3335
import org.springframework.ai.ollama.api.OllamaOptions;
3436
import org.springframework.util.Assert;
37+
import org.springframework.util.StringUtils;
3538

3639
/**
37-
* {@link EmbeddingClient} implementation for {@literal Ollma}.
40+
* {@link EmbeddingClient} implementation for {@literal Ollama}.
3841
*
3942
* Ollama allows developers to run large language models and generate embeddings locally.
4043
* It supports open-source models available on [Ollama AI
@@ -43,39 +46,38 @@
4346
* Examples of models supported: - Llama 2 (7B parameters, 3.8GB size) - Mistral (7B
4447
* parameters, 4.1GB size)
4548
*
46-
*
47-
*
4849
* Please refer to the <a href="https://ollama.ai/">official Ollama website</a> for the
4950
* most up-to-date information on available models.
5051
*
5152
* @author Christian Tzolov
53+
* @since 0.8.0
5254
*/
5355
public class OllamaEmbeddingClient extends AbstractEmbeddingClient {
5456

5557
private final Logger logger = LoggerFactory.getLogger(getClass());
5658

5759
private final OllamaApi ollamaApi;
5860

59-
private String model = "orca-mini";
60-
61-
private Map<String, Object> clientOptions;
61+
/**
62+
* Default options to be used for all embedding requests.
63+
*/
64+
private OllamaOptions defaultOptions = OllamaOptions.create().withModel(OllamaOptions.DEFAULT_MODEL);
6265

6366
public OllamaEmbeddingClient(OllamaApi ollamaApi) {
6467
this.ollamaApi = ollamaApi;
6568
}
6669

70+
/**
71+
* @deprecated Use {@link OllamaOptions#setModel} instead.
72+
*/
73+
@Deprecated
6774
public OllamaEmbeddingClient withModel(String model) {
68-
this.model = model;
75+
this.defaultOptions.setModel(model);
6976
return this;
7077
}
7178

72-
public OllamaEmbeddingClient withOptions(Map<String, Object> options) {
73-
this.clientOptions = options;
74-
return this;
75-
}
76-
77-
public OllamaEmbeddingClient withOptions(OllamaOptions options) {
78-
this.clientOptions = options.toMap();
79+
public OllamaEmbeddingClient withDefaultOptions(OllamaOptions options) {
80+
this.defaultOptions = options;
7981
return this;
8082
}
8183

@@ -94,15 +96,51 @@ public EmbeddingResponse call(org.springframework.ai.embedding.EmbeddingRequest
9496

9597
List<List<Double>> embeddingList = new ArrayList<>();
9698
for (String inputContent : request.getInstructions()) {
97-
OllamaApi.EmbeddingResponse response = this.ollamaApi
98-
.embeddings(new EmbeddingRequest(this.model, inputContent, this.clientOptions));
99+
100+
var ollamaEmbeddingRequest = ollamaEmbeddingRequest(inputContent, request.getOptions());
101+
102+
OllamaApi.EmbeddingResponse response = this.ollamaApi.embeddings(ollamaEmbeddingRequest);
103+
99104
embeddingList.add(response.embedding());
100105
}
101106
var indexCounter = new AtomicInteger(0);
107+
102108
List<Embedding> embeddings = embeddingList.stream()
103109
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
104110
.toList();
105111
return new EmbeddingResponse(embeddings);
106112
}
107113

114+
/**
115+
* Package access for testing.
116+
*/
117+
OllamaApi.EmbeddingRequest ollamaEmbeddingRequest(String inputContent, EmbeddingOptions options) {
118+
119+
// runtime options
120+
OllamaOptions runtimeOptions = null;
121+
if (options != null) {
122+
if (options instanceof OllamaOptions ollamaOptions) {
123+
runtimeOptions = ollamaOptions;
124+
}
125+
else if (options instanceof EmbeddingOptions embeddingOptions) {
126+
// currently EmbeddingOptions does not have any portable options to be
127+
// merged.
128+
runtimeOptions = null;
129+
}
130+
else {
131+
throw new IllegalArgumentException("Request embedding options are not of type EmbeddingOptions: "
132+
+ options.getClass().getSimpleName());
133+
}
134+
}
135+
136+
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
137+
138+
// Ensure that a model is set after merging the options.
139+
if (!StringUtils.hasText(mergedOptions.getModel())) {
140+
throw new IllegalArgumentException("Model is not set!");
141+
}
142+
String model = mergedOptions.getModel();
143+
return new EmbeddingRequest(model, inputContent, OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()));
144+
}
145+
108146
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2023 the original author or authors.
2+
* Copyright 2023-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
2121
import java.time.Instant;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.Objects;
2425
import java.util.function.Consumer;
2526

2627
import com.fasterxml.jackson.annotation.JsonInclude;
@@ -445,19 +446,21 @@ public Builder withFormat(String format) {
445446
}
446447

447448
public Builder withOptions(Map<String, Object> options) {
448-
this.options = options;
449+
Objects.requireNonNullElse(options, "The options can not be null.");
450+
451+
this.options = OllamaOptions.filterNonSupportedFields(options);
449452
return this;
450453
}
451454

452455
public Builder withOptions(OllamaOptions options) {
453-
this.options = options.toMap();
456+
Objects.requireNonNullElse(options, "The options can not be null.");
457+
this.options = OllamaOptions.filterNonSupportedFields(options.toMap());
454458
return this;
455459
}
456460

457461
public ChatRequest build() {
458462
return new ChatRequest(model, messages, stream, format, options);
459463
}
460-
461464
}
462465
}
463466

0 commit comments

Comments
 (0)