Skip to content

Commit 3d1825d

Browse files
ThomasVitale authored and
ilayaperumalg committed
ollama: Adopt new strategy for ObservationContext
Relates to gh-2518. Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 20deb0e commit 3d1825d

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaEmbeddingModel.java

Lines changed: 26 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@
3131
import org.springframework.ai.embedding.Embedding;
3232
import org.springframework.ai.embedding.EmbeddingModel;
3333
import org.springframework.ai.embedding.EmbeddingOptions;
34-
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
3534
import org.springframework.ai.embedding.EmbeddingRequest;
3635
import org.springframework.ai.embedding.EmbeddingResponse;
3736
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
@@ -105,13 +104,16 @@ public float[] embed(Document document) {
105104
public EmbeddingResponse call(EmbeddingRequest request) {
106105
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
107106

108-
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(request.getInstructions(),
109-
request.getOptions());
107+
// Before moving any further, build the final request EmbeddingRequest,
108+
// merging runtime and default options.
109+
EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);
110+
111+
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(embeddingRequest);
110112

111113
var observationContext = EmbeddingModelObservationContext.builder()
112114
.embeddingRequest(request)
113115
.provider(OllamaApi.PROVIDER_NAME)
114-
.requestOptions(buildRequestOptions(ollamaEmbeddingRequest))
116+
.requestOptions(embeddingRequest.getOptions())
115117
.build();
116118

117119
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
@@ -142,31 +144,34 @@ private DefaultUsage getDefaultUsage(OllamaApi.EmbeddingsResponse response) {
142144
return new DefaultUsage(Optional.ofNullable(response.promptEvalCount()).orElse(0), 0);
143145
}
144146

145-
/**
146-
* Package access for testing.
147-
*/
148-
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, EmbeddingOptions options) {
149-
150-
// runtime options
147+
EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
148+
// Process runtime options
151149
OllamaOptions runtimeOptions = null;
152-
if (options != null && options instanceof OllamaOptions ollamaOptions) {
153-
runtimeOptions = ollamaOptions;
150+
if (embeddingRequest.getOptions() != null) {
151+
runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
152+
OllamaOptions.class);
154153
}
155154

156-
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
155+
// Define request options by merging runtime options and default options
156+
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
157+
OllamaOptions.class);
157158

158-
// Override the model.
159-
if (!StringUtils.hasText(mergedOptions.getModel())) {
160-
throw new IllegalArgumentException("Model is not set!");
159+
// Validate request options
160+
if (!StringUtils.hasText(requestOptions.getModel())) {
161+
throw new IllegalArgumentException("model cannot be null or empty");
161162
}
162-
String model = mergedOptions.getModel();
163163

164-
return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
165-
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
164+
return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
166165
}
167166

168-
private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request) {
169-
return EmbeddingOptionsBuilder.builder().withModel(request.model()).build();
167+
/**
168+
* Package access for testing.
169+
*/
170+
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(EmbeddingRequest embeddingRequest) {
171+
OllamaOptions requestOptions = (OllamaOptions) embeddingRequest.getOptions();
172+
return new OllamaApi.EmbeddingsRequest(requestOptions.getModel(), embeddingRequest.getInstructions(),
173+
DurationParser.parse(requestOptions.getKeepAlive()),
174+
OllamaOptions.filterNonSupportedFields(requestOptions.toMap()), requestOptions.getTruncate());
170175
}
171176

172177
/**

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaEmbeddingRequestTests.java

Lines changed: 21 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
import org.junit.jupiter.api.Test;
2323

24+
import org.springframework.ai.embedding.EmbeddingRequest;
2425
import org.springframework.ai.ollama.api.OllamaApi;
2526
import org.springframework.ai.ollama.api.OllamaOptions;
2627

@@ -40,43 +41,45 @@ public class OllamaEmbeddingRequestTests {
4041

4142
@Test
4243
public void ollamaEmbeddingRequestDefaultOptions() {
43-
44-
var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), null);
45-
46-
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
47-
assertThat(request.options().get("num_gpu")).isEqualTo(1);
48-
assertThat(request.options().get("main_gpu")).isEqualTo(11);
49-
assertThat(request.options().get("use_mmap")).isEqualTo(true);
50-
assertThat(request.input()).isEqualTo(List.of("Hello"));
44+
var embeddingRequest = this.embeddingModel.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), null));
45+
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
46+
47+
assertThat(ollamaRequest.model()).isEqualTo("DEFAULT_MODEL");
48+
assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(1);
49+
assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(11);
50+
assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true);
51+
assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello"));
5152
}
5253

5354
@Test
5455
public void ollamaEmbeddingRequestRequestOptions() {
55-
5656
var promptOptions = OllamaOptions.builder()//
5757
.model("PROMPT_MODEL")//
5858
.mainGPU(22)//
5959
.useMMap(true)//
6060
.numGPU(2)
6161
.build();
6262

63-
var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions);
63+
var embeddingRequest = this.embeddingModel
64+
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions));
65+
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
6466

65-
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
66-
assertThat(request.options().get("num_gpu")).isEqualTo(2);
67-
assertThat(request.options().get("main_gpu")).isEqualTo(22);
68-
assertThat(request.options().get("use_mmap")).isEqualTo(true);
69-
assertThat(request.input()).isEqualTo(List.of("Hello"));
67+
assertThat(ollamaRequest.model()).isEqualTo("PROMPT_MODEL");
68+
assertThat(ollamaRequest.options().get("num_gpu")).isEqualTo(2);
69+
assertThat(ollamaRequest.options().get("main_gpu")).isEqualTo(22);
70+
assertThat(ollamaRequest.options().get("use_mmap")).isEqualTo(true);
71+
assertThat(ollamaRequest.input()).isEqualTo(List.of("Hello"));
7072
}
7173

7274
@Test
7375
public void ollamaEmbeddingRequestWithNegativeKeepAlive() {
74-
7576
var promptOptions = OllamaOptions.builder().model("PROMPT_MODEL").keepAlive("-1m").build();
7677

77-
var request = this.embeddingModel.ollamaEmbeddingRequest(List.of("Hello"), promptOptions);
78+
var embeddingRequest = this.embeddingModel
79+
.buildEmbeddingRequest(new EmbeddingRequest(List.of("Hello"), promptOptions));
80+
var ollamaRequest = this.embeddingModel.ollamaEmbeddingRequest(embeddingRequest);
7881

79-
assertThat(request.keepAlive()).isEqualTo(Duration.ofMinutes(-1));
82+
assertThat(ollamaRequest.keepAlive()).isEqualTo(Duration.ofMinutes(-1));
8083
}
8184

8285
}

0 commit comments

Comments
 (0)