Skip to content

Commit 24fb447

Browse files
ThomasVitaleilayaperumalg
authored andcommitted
openai: Adopt new strategy for ObservationContext
Relates to gh-2518 Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
1 parent 3d1825d commit 24fb447

File tree

3 files changed

+54
-52
lines changed

3 files changed

+54
-52
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -42,7 +42,6 @@
4242
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
4343
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
4444
import org.springframework.ai.retry.RetryUtils;
45-
import org.springframework.lang.Nullable;
4645
import org.springframework.retry.support.RetryTemplate;
4746
import org.springframework.util.Assert;
4847

@@ -148,13 +147,16 @@ public float[] embed(Document document) {
148147

149148
@Override
150149
public EmbeddingResponse call(EmbeddingRequest request) {
151-
OpenAiEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);
152-
OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(request, requestOptions);
150+
// Before moving any further, build the final request EmbeddingRequest,
151+
// merging runtime and default options.
152+
EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);
153+
154+
OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(embeddingRequest);
153155

154156
var observationContext = EmbeddingModelObservationContext.builder()
155157
.embeddingRequest(request)
156158
.provider(OpenAiApiConstants.PROVIDER_NAME)
157-
.requestOptions(requestOptions)
159+
.requestOptions(embeddingRequest.getOptions())
158160
.build();
159161

160162
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
@@ -190,35 +192,32 @@ private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
190192
return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
191193
}
192194

193-
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request,
194-
OpenAiEmbeddingOptions requestOptions) {
195+
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request) {
196+
OpenAiEmbeddingOptions requestOptions = (OpenAiEmbeddingOptions) request.getOptions();
195197
return new OpenAiApi.EmbeddingRequest<>(request.getInstructions(), requestOptions.getModel(),
196198
requestOptions.getEncodingFormat(), requestOptions.getDimensions(), requestOptions.getUser());
197199
}
198200

199-
/**
200-
* Merge runtime and default {@link EmbeddingOptions} to compute the final options to
201-
* use in the request.
202-
*/
203-
private OpenAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions,
204-
OpenAiEmbeddingOptions defaultOptions) {
205-
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class,
206-
OpenAiEmbeddingOptions.class);
207-
208-
if (runtimeOptionsForProvider == null) {
209-
return defaultOptions;
201+
private EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
202+
// Process runtime options
203+
OpenAiEmbeddingOptions runtimeOptions = null;
204+
if (embeddingRequest.getOptions() != null) {
205+
runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
206+
OpenAiEmbeddingOptions.class);
210207
}
211208

212-
return OpenAiEmbeddingOptions.builder()
209+
OpenAiEmbeddingOptions requestOptions = runtimeOptions == null ? this.defaultOptions : OpenAiEmbeddingOptions
210+
.builder()
213211
// Handle portable embedding options
214-
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
215-
.dimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(),
216-
defaultOptions.getDimensions()))
212+
.model(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), this.defaultOptions.getModel()))
213+
.dimensions(ModelOptionsUtils.mergeOption(runtimeOptions.getDimensions(), defaultOptions.getDimensions()))
217214
// Handle OpenAI specific embedding options
218-
.encodingFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getEncodingFormat(),
215+
.encodingFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getEncodingFormat(),
219216
defaultOptions.getEncodingFormat()))
220-
.user(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser()))
217+
.user(ModelOptionsUtils.mergeOption(runtimeOptions.getUser(), this.defaultOptions.getUser()))
221218
.build();
219+
220+
return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
222221
}
223222

224223
/**

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -39,7 +39,6 @@
3939
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
4040
import org.springframework.ai.retry.RetryUtils;
4141
import org.springframework.http.ResponseEntity;
42-
import org.springframework.lang.Nullable;
4342
import org.springframework.retry.support.RetryTemplate;
4443
import org.springframework.util.Assert;
4544

@@ -127,13 +126,16 @@ public OpenAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageOptions option
127126

128127
@Override
129128
public ImageResponse call(ImagePrompt imagePrompt) {
130-
OpenAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
131-
OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
129+
// Before moving any further, build the final request ImagePrompt,
130+
// merging runtime and default options.
131+
ImagePrompt requestImagePrompt = buildRequestImagePrompt(imagePrompt);
132+
133+
OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(requestImagePrompt);
132134

133135
var observationContext = ImageModelObservationContext.builder()
134136
.imagePrompt(imagePrompt)
135137
.provider(OpenAiApiConstants.PROVIDER_NAME)
136-
.requestOptions(requestImageOptions)
138+
.requestOptions(requestImagePrompt.getOptions())
137139
.build();
138140

139141
return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
@@ -151,14 +153,14 @@ public ImageResponse call(ImagePrompt imagePrompt) {
151153
});
152154
}
153155

154-
private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt,
155-
OpenAiImageOptions requestImageOptions) {
156+
private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt) {
156157
String instructions = imagePrompt.getInstructions().get(0).getText();
158+
OpenAiImageOptions imageOptions = (OpenAiImageOptions) imagePrompt.getOptions();
157159

158160
OpenAiImageApi.OpenAiImageRequest imageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions,
159161
OpenAiImageApi.DEFAULT_IMAGE_MODEL);
160162

161-
return ModelOptionsUtils.merge(requestImageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class);
163+
return ModelOptionsUtils.merge(imageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class);
162164
}
163165

164166
private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
@@ -179,31 +181,29 @@ private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageR
179181
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
180182
}
181183

182-
/**
183-
* Merge runtime and default {@link ImageOptions} to compute the final options to use
184-
* in the request.
185-
*/
186-
private OpenAiImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, OpenAiImageOptions defaultOptions) {
187-
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
188-
OpenAiImageOptions.class);
189-
190-
if (runtimeOptionsForProvider == null) {
191-
return defaultOptions;
184+
private ImagePrompt buildRequestImagePrompt(ImagePrompt imagePrompt) {
185+
// Process runtime options
186+
OpenAiImageOptions runtimeOptions = null;
187+
if (imagePrompt.getOptions() != null) {
188+
runtimeOptions = ModelOptionsUtils.copyToTarget(imagePrompt.getOptions(), ImageOptions.class,
189+
OpenAiImageOptions.class);
192190
}
193191

194-
return OpenAiImageOptions.builder()
192+
OpenAiImageOptions requestOptions = runtimeOptions == null ? this.defaultOptions : OpenAiImageOptions.builder()
195193
// Handle portable image options
196-
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
197-
.N(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
198-
.responseFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getResponseFormat(),
194+
.model(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel()))
195+
.N(ModelOptionsUtils.mergeOption(runtimeOptions.getN(), defaultOptions.getN()))
196+
.responseFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getResponseFormat(),
199197
defaultOptions.getResponseFormat()))
200-
.width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
201-
.height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
202-
.style(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getStyle(), defaultOptions.getStyle()))
198+
.width(ModelOptionsUtils.mergeOption(runtimeOptions.getWidth(), defaultOptions.getWidth()))
199+
.height(ModelOptionsUtils.mergeOption(runtimeOptions.getHeight(), defaultOptions.getHeight()))
200+
.style(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle()))
203201
// Handle OpenAI specific image options
204-
.quality(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getQuality(), defaultOptions.getQuality()))
205-
.user(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser()))
202+
.quality(ModelOptionsUtils.mergeOption(runtimeOptions.getQuality(), defaultOptions.getQuality()))
203+
.user(ModelOptionsUtils.mergeOption(runtimeOptions.getUser(), defaultOptions.getUser()))
206204
.build();
205+
206+
return new ImagePrompt(imagePrompt.getInstructions(), requestOptions);
207207
}
208208

209209
/**

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingRequest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.List;
2020

2121
import org.springframework.ai.model.ModelRequest;
22+
import org.springframework.lang.Nullable;
2223

2324
/**
2425
* Request to embed a list of input instructions.
@@ -29,9 +30,10 @@ public class EmbeddingRequest implements ModelRequest<List<String>> {
2930

3031
private final List<String> inputs;
3132

33+
@Nullable
3234
private final EmbeddingOptions options;
3335

34-
public EmbeddingRequest(List<String> inputs, EmbeddingOptions options) {
36+
public EmbeddingRequest(List<String> inputs, @Nullable EmbeddingOptions options) {
3537
this.inputs = inputs;
3638
this.options = options;
3739
}
@@ -42,6 +44,7 @@ public List<String> getInstructions() {
4244
}
4345

4446
@Override
47+
@Nullable
4548
public EmbeddingOptions getOptions() {
4649
return this.options;
4750
}

0 commit comments

Comments
 (0)