Skip to content

Commit 4cacbe8

Browse files
committed
Refactor Ollama embedding model implementation
- Update OllamaEmbeddingModel to support batch embedding requests - Rename EmbeddingRequest/Response to EmbeddingsRequest/Response - Add truncate option to control input truncation - Update documentation and tests for new embedding API - Remove deprecated withModel and withDefaultOptions methods - Adjust default values for various Ollama model options - Include response medata with response model name Resolves #1158
1 parent 3978e8e commit 4cacbe8

File tree

10 files changed

+371
-156
lines changed

10 files changed

+371
-156
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 56 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,26 @@
1515
*/
1616
package org.springframework.ai.ollama;
1717

18-
import java.util.ArrayList;
18+
import java.time.Duration;
1919
import java.util.List;
2020
import java.util.concurrent.atomic.AtomicInteger;
21+
import java.util.regex.Matcher;
22+
import java.util.regex.Pattern;
2123

2224
import org.slf4j.Logger;
2325
import org.slf4j.LoggerFactory;
24-
26+
import org.springframework.ai.chat.metadata.EmptyUsage;
2527
import org.springframework.ai.document.Document;
2628
import org.springframework.ai.embedding.AbstractEmbeddingModel;
2729
import org.springframework.ai.embedding.Embedding;
2830
import org.springframework.ai.embedding.EmbeddingModel;
2931
import org.springframework.ai.embedding.EmbeddingOptions;
32+
import org.springframework.ai.embedding.EmbeddingRequest;
3033
import org.springframework.ai.embedding.EmbeddingResponse;
34+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
3135
import org.springframework.ai.model.ModelOptionsUtils;
3236
import org.springframework.ai.ollama.api.OllamaApi;
33-
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingRequest;
37+
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
3438
import org.springframework.ai.ollama.api.OllamaOptions;
3539
import org.springframework.util.Assert;
3640
import org.springframework.util.StringUtils;
@@ -71,70 +75,43 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
7175
this.defaultOptions = defaultOptions;
7276
}
7377

74-
/**
75-
* @deprecated Use {@link OllamaOptions#setModel} instead.
76-
*/
77-
@Deprecated
78-
public OllamaEmbeddingModel withModel(String model) {
79-
this.defaultOptions.setModel(model);
80-
return this;
81-
}
82-
83-
/**
84-
* @deprecated Use {@link OllamaOptions} constructor instead.
85-
*/
86-
@Deprecated
87-
public OllamaEmbeddingModel withDefaultOptions(OllamaOptions options) {
88-
this.defaultOptions = options;
89-
return this;
90-
}
91-
9278
@Override
9379
public List<Double> embed(Document document) {
9480
return embed(document.getContent());
9581
}
9682

9783
@Override
98-
public EmbeddingResponse call(org.springframework.ai.embedding.EmbeddingRequest request) {
99-
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
100-
if (request.getInstructions().size() != 1) {
101-
logger.warn(
102-
"Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
103-
}
84+
public EmbeddingResponse call(EmbeddingRequest request) {
10485

105-
List<List<Double>> embeddingList = new ArrayList<>();
106-
for (String inputContent : request.getInstructions()) {
86+
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
10787

108-
EmbeddingRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(inputContent, request.getOptions());
88+
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(request.getInstructions(),
89+
request.getOptions());
10990

110-
OllamaApi.EmbeddingResponse response = this.ollamaApi.embeddings(ollamaEmbeddingRequest);
91+
EmbeddingsResponse response = this.ollamaApi.embed(ollamaEmbeddingRequest);
11192

112-
embeddingList.add(response.embedding());
113-
}
11493
AtomicInteger indexCounter = new AtomicInteger(0);
11594

116-
List<Embedding> embeddings = embeddingList.stream()
95+
List<Embedding> embeddings = response.embeddings()
96+
.stream()
11797
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
11898
.toList();
119-
return new EmbeddingResponse(embeddings);
99+
100+
EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(),
101+
new EmptyUsage());
102+
103+
return new EmbeddingResponse(embeddings, embeddingResponseMetadata);
120104
}
121105

122106
/**
123107
* Package access for testing.
124108
*/
125-
OllamaApi.EmbeddingRequest ollamaEmbeddingRequest(String inputContent, EmbeddingOptions options) {
109+
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, EmbeddingOptions options) {
126110

127111
// runtime options
128112
OllamaOptions runtimeOptions = null;
129-
if (options != null) {
130-
if (options instanceof OllamaOptions ollamaOptions) {
131-
runtimeOptions = ollamaOptions;
132-
}
133-
else {
134-
// currently EmbeddingOptions does not have any portable options to be
135-
// merged.
136-
runtimeOptions = null;
137-
}
113+
if (options != null && options instanceof OllamaOptions ollamaOptions) {
114+
runtimeOptions = ollamaOptions;
138115
}
139116

140117
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
@@ -144,8 +121,40 @@ OllamaApi.EmbeddingRequest ollamaEmbeddingRequest(String inputContent, Embedding
144121
throw new IllegalArgumentException("Model is not set!");
145122
}
146123
String model = mergedOptions.getModel();
147-
return new EmbeddingRequest(model, inputContent, null,
148-
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()));
124+
125+
return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
126+
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
127+
}
128+
129+
public static class DurationParser {
130+
131+
private static Pattern PATTERN = Pattern.compile("(\\d+)(ms|s|m|h)");
132+
133+
public static Duration parse(String input) {
134+
135+
if (!StringUtils.hasText(input)) {
136+
return null;
137+
}
138+
139+
Matcher matcher = PATTERN.matcher(input);
140+
141+
if (matcher.matches()) {
142+
long value = Long.parseLong(matcher.group(1));
143+
String unit = matcher.group(2);
144+
145+
return switch (unit) {
146+
case "ms" -> Duration.ofMillis(value);
147+
case "s" -> Duration.ofSeconds(value);
148+
case "m" -> Duration.ofMinutes(value);
149+
case "h" -> Duration.ofHours(value);
150+
default -> throw new IllegalArgumentException("Unsupported time unit: " + unit);
151+
};
152+
}
153+
else {
154+
throw new IllegalArgumentException("Invalid duration format: " + input);
155+
}
156+
}
157+
149158
}
150159

151160
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ public record GenerateResponse(
307307
@JsonProperty("eval_duration") Duration evalDuration) {
308308
}
309309

310-
/**
310+
/**
311311
* Generate a completion for the given prompt.
312312
* @param completionRequest Completion request.
313313
* @return Completion response.
@@ -691,11 +691,40 @@ public Flux<ChatResponse> streamingChat(ChatRequest chatRequest) {
691691
* Generate embeddings from a model.
692692
*
693693
* @param model The name of model to generate embeddings from.
694-
* @param prompt The text to generate embeddings for.
694+
* @param input The text or list of text to generate embeddings for.
695+
* @param keepAlive Controls how long the model will stay loaded into memory following the request (default: 5m).
696+
* @param options Additional model parameters listed in the documentation for the
697+
* @param truncate Truncates the end of each input to fit within context length.
698+
* Returns error if false and context length is exceeded. Defaults to true.
699+
*/
700+
@JsonInclude(Include.NON_NULL)
701+
public record EmbeddingsRequest(
702+
@JsonProperty("model") String model,
703+
@JsonProperty("input") List<String> input,
704+
@JsonProperty("keep_alive") Duration keepAlive,
705+
@JsonProperty("options") Map<String, Object> options,
706+
@JsonProperty("truncate") Boolean truncate) {
707+
708+
/**
709+
* Shortcut constructor to create a EmbeddingRequest without options.
710+
* @param model The name of model to generate embeddings from.
711+
* @param input The text or list of text to generate embeddings for.
712+
*/
713+
public EmbeddingsRequest(String model, String input) {
714+
this(model, List.of(input), null, null, null);
715+
}
716+
}
717+
718+
/**
719+
* Generate embeddings from a model.
720+
*
721+
* @param model The name of model to generate embeddings from.
722+
* @param prompt The text generate embeddings for
695723
* @param keepAlive Controls how long the model will stay loaded into memory following the request (default: 5m).
696724
* @param options Additional model parameters listed in the documentation for the
697-
* Model file such as temperature.
725+
* @deprecated Use {@link EmbeddingsRequest} instead.
698726
*/
727+
@Deprecated(since = "1.0.0-M2", forRemoval = true)
699728
@JsonInclude(Include.NON_NULL)
700729
public record EmbeddingRequest(
701730
@JsonProperty("model") String model,
@@ -717,17 +746,49 @@ public EmbeddingRequest(String model, String prompt) {
717746
* The response object returned from the /embedding endpoint.
718747
*
719748
* @param embedding The embedding generated from the model.
749+
* @deprecated Use {@link EmbeddingsResponse} instead.
720750
*/
751+
@Deprecated(since = "1.0.0-M2", forRemoval = true)
721752
@JsonInclude(Include.NON_NULL)
722753
public record EmbeddingResponse(
723754
@JsonProperty("embedding") List<Double> embedding) {
724755
}
725756

757+
758+
/**
759+
* The response object returned from the /embedding endpoint.
760+
* @param model The model used for generating the embeddings.
761+
* @param embeddings The list of embeddings generated from the model.
762+
* Each embedding (list of doubles) corresponds to a single input text.
763+
*/
764+
@JsonInclude(Include.NON_NULL)
765+
public record EmbeddingsResponse(
766+
@JsonProperty("model") String model,
767+
@JsonProperty("embeddings") List<List<Double>> embeddings) {
768+
}
769+
770+
/**
771+
* Generate embeddings from a model.
772+
* @param embeddingsRequest Embedding request.
773+
* @return Embeddings response.
774+
*/
775+
public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
776+
Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR);
777+
778+
return this.restClient.post()
779+
.uri("/api/embed")
780+
.body(embeddingsRequest)
781+
.retrieve()
782+
.onStatus(this.responseErrorHandler)
783+
.body(EmbeddingsResponse.class);
784+
}
726785
/**
727786
* Generate embeddings from a model.
728787
* @param embeddingRequest Embedding request.
729788
* @return Embedding response.
789+
* @deprecated Use {@link #embed(EmbeddingsRequest)} instead.
730790
*/
791+
@Deprecated(since = "1.0.0-M2", forRemoval = true)
731792
public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
732793
Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_ERROR);
733794

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
5252

5353
public static final String DEFAULT_MODEL = OllamaModel.MISTRAL.id();
5454

55-
private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive");
55+
private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate");
5656

5757
// Following fields are options which must be set when the model is loaded into
5858
// memory.
@@ -267,6 +267,13 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
267267
* Part of Chat completion <a href="https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1">advanced parameters</a>.
268268
*/
269269
@JsonProperty("keep_alive") private String keepAlive;
270+
271+
272+
/**
273+
* Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded.
274+
* Defaults to true.
275+
*/
276+
@JsonProperty("truncate") private Boolean truncate;
270277

271278
/**
272279
* Tool Function Callbacks to register with the ChatModel.
@@ -312,14 +319,6 @@ public OllamaOptions withModel(OllamaModel model) {
312319
return this;
313320
}
314321

315-
public String getModel() {
316-
return model;
317-
}
318-
319-
public void setModel(String model) {
320-
this.model = model;
321-
}
322-
323322
public OllamaOptions withFormat(String format) {
324323
this.format = format;
325324
return this;
@@ -330,6 +329,11 @@ public OllamaOptions withKeepAlive(String keepAlive) {
330329
return this;
331330
}
332331

332+
public OllamaOptions withTruncate(Boolean truncate) {
333+
this.truncate = truncate;
334+
return this;
335+
}
336+
333337
public OllamaOptions withUseNUMA(Boolean useNUMA) {
334338
this.useNUMA = useNUMA;
335339
return this;
@@ -491,6 +495,17 @@ public OllamaOptions withFunction(String functionName) {
491495
return this;
492496
}
493497

498+
// -------------------
499+
// Getters and Setters
500+
// -------------------
501+
public String getModel() {
502+
return model;
503+
}
504+
505+
public void setModel(String model) {
506+
this.model = model;
507+
}
508+
494509
public String getFormat() {
495510
return this.format;
496511
}
@@ -739,6 +754,14 @@ public void setStop(List<String> stop) {
739754
this.stop = stop;
740755
}
741756

757+
public Boolean getTruncate() {
758+
return this.truncate;
759+
}
760+
761+
public void setTruncate(Boolean truncate) {
762+
this.truncate = truncate;
763+
}
764+
742765
@Override
743766
public List<FunctionCallback> getFunctionCallbacks() {
744767
return this.functionCallbacks;
@@ -797,6 +820,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
797820
.withModel(fromOptions.getModel())
798821
.withFormat(fromOptions.getFormat())
799822
.withKeepAlive(fromOptions.getKeepAlive())
823+
.withTruncate(fromOptions.getTruncate())
800824
.withUseNUMA(fromOptions.getUseNUMA())
801825
.withNumCtx(fromOptions.getNumCtx())
802826
.withNumBatch(fromOptions.getNumBatch())
@@ -839,15 +863,16 @@ public boolean equals(Object o) {
839863
return false;
840864
OllamaOptions that = (OllamaOptions) o;
841865
return Objects.equals(model, that.model) && Objects.equals(format, that.format)
842-
&& Objects.equals(keepAlive, that.keepAlive) && Objects.equals(useNUMA, that.useNUMA)
843-
&& Objects.equals(numCtx, that.numCtx) && Objects.equals(numBatch, that.numBatch)
844-
&& Objects.equals(numGPU, that.numGPU) && Objects.equals(mainGPU, that.mainGPU)
845-
&& Objects.equals(lowVRAM, that.lowVRAM) && Objects.equals(f16KV, that.f16KV)
846-
&& Objects.equals(logitsAll, that.logitsAll) && Objects.equals(vocabOnly, that.vocabOnly)
847-
&& Objects.equals(useMMap, that.useMMap) && Objects.equals(useMLock, that.useMLock)
848-
&& Objects.equals(numThread, that.numThread) && Objects.equals(numKeep, that.numKeep)
849-
&& Objects.equals(seed, that.seed) && Objects.equals(numPredict, that.numPredict)
850-
&& Objects.equals(topK, that.topK) && Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ)
866+
&& Objects.equals(keepAlive, that.keepAlive) && Objects.equals(truncate, that.truncate)
867+
&& Objects.equals(useNUMA, that.useNUMA) && Objects.equals(numCtx, that.numCtx)
868+
&& Objects.equals(numBatch, that.numBatch) && Objects.equals(numGPU, that.numGPU)
869+
&& Objects.equals(mainGPU, that.mainGPU) && Objects.equals(lowVRAM, that.lowVRAM)
870+
&& Objects.equals(f16KV, that.f16KV) && Objects.equals(logitsAll, that.logitsAll)
871+
&& Objects.equals(vocabOnly, that.vocabOnly) && Objects.equals(useMMap, that.useMMap)
872+
&& Objects.equals(useMLock, that.useMLock) && Objects.equals(numThread, that.numThread)
873+
&& Objects.equals(numKeep, that.numKeep) && Objects.equals(seed, that.seed)
874+
&& Objects.equals(numPredict, that.numPredict) && Objects.equals(topK, that.topK)
875+
&& Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ)
851876
&& Objects.equals(typicalP, that.typicalP) && Objects.equals(repeatLastN, that.repeatLastN)
852877
&& Objects.equals(temperature, that.temperature) && Objects.equals(repeatPenalty, that.repeatPenalty)
853878
&& Objects.equals(presencePenalty, that.presencePenalty)
@@ -860,12 +885,12 @@ public boolean equals(Object o) {
860885

861886
@Override
862887
public int hashCode() {
863-
return Objects.hash(this.model, this.format, this.keepAlive, this.useNUMA, this.numCtx, this.numBatch,
864-
this.numGPU, this.mainGPU, lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly, this.useMMap,
865-
this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK, this.topP, tfsZ,
866-
this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty,
867-
this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, this.penalizeNewline,
868-
this.stop, this.functionCallbacks, this.functions);
888+
return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx,
889+
this.numBatch, this.numGPU, this.mainGPU, lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly,
890+
this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK,
891+
this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
892+
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
893+
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions);
869894
}
870895

871896
}

0 commit comments

Comments
 (0)