Skip to content

Commit 60a60b4

Browse files
committed
Fix handling for openai chat portable options
- Now you can use an arbitrary ChatOption instance as Prompt options for the OpenAIChatClinet. - Add Unit tests for OpenAiApi and ModelOptionsUtils. - Document portable options support.
1 parent d71f90f commit 60a60b4

File tree

5 files changed

+359
-15
lines changed

5 files changed

+359
-15
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import reactor.core.publisher.Flux;
2626

2727
import org.springframework.ai.chat.ChatClient;
28+
import org.springframework.ai.chat.ChatOptions;
2829
import org.springframework.ai.chat.ChatResponse;
2930
import org.springframework.ai.chat.Generation;
3031
import org.springframework.ai.chat.StreamingChatClient;
@@ -166,8 +167,10 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
166167
}
167168

168169
if (prompt.getOptions() != null) {
169-
if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) {
170-
request = ModelOptionsUtils.merge(runtimeOptions, request, ChatCompletionRequest.class);
170+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
171+
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
172+
ChatOptions.class, OpenAiChatOptions.class);
173+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
171174
}
172175
else {
173176
throw new IllegalArgumentException("Prompt options are not of type ChatCompletionRequest:"

spring-ai-core/src/main/java/org/springframework/ai/chat/ChatOptions.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
import org.springframework.ai.model.ModelOptions;
2020

2121
/**
22-
* portable options
22+
* The ChatOptions represent the common options, portable across different chat models.
2323
*/
2424
public interface ChatOptions extends ModelOptions {
2525

26-
// determine portable optionsb
27-
2826
Float getTemperature();
2927

3028
void setTemperature(Float temperature);

spring-ai-core/src/main/java/org/springframework/ai/model/ModelOptionsUtils.java

Lines changed: 119 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
package org.springframework.ai.model;
1818

19+
import java.beans.PropertyDescriptor;
1920
import java.lang.reflect.Field;
2021
import java.util.ArrayList;
22+
import java.util.Arrays;
2123
import java.util.HashMap;
2224
import java.util.List;
2325
import java.util.Map;
@@ -30,25 +32,37 @@
3032
import com.fasterxml.jackson.databind.ObjectMapper;
3133
import com.fasterxml.jackson.databind.SerializationFeature;
3234

35+
import org.springframework.beans.BeanWrapper;
36+
import org.springframework.beans.BeanWrapperImpl;
37+
import org.springframework.util.Assert;
3338
import org.springframework.util.CollectionUtils;
3439

3540
/**
3641
* Utility class for manipulating {@link ModelOptions} objects.
3742
*
3843
* @author Christian Tzolov
44+
* @since 0.8.0
3945
*/
4046
public final class ModelOptionsUtils {
4147

4248
private final static ObjectMapper OBJECT_MAPPER = new ObjectMapper()
4349
.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);
4450

51+
private final static List<String> BEAN_MERGE_FIELD_EXCISIONS = List.of("class");
52+
53+
private static ConcurrentHashMap<Class<?>, List<String>> REQUEST_FIELD_NAMES_PER_CLASS = new ConcurrentHashMap<Class<?>, List<String>>();
54+
4555
private ModelOptionsUtils() {
4656

4757
}
4858

4959
/**
5060
* Merges the source object into the target object and returns an object represented
51-
* by the given class. The source null values are ignored.
61+
* by the given class. The JSON property names are used to match the fields to merge.
62+
* The source non-null values override the target values with the same field name. The
63+
* source null values are ignored. If the acceptedFieldNames is not empty, only the
64+
* fields with the given names are merged and returned. If the acceptedFieldNames is
65+
* empty, use the {@code @JsonProperty} names, inferred from the provided clazz.
5266
* @param <T> they type of the class to return.
5367
* @param source the source object to merge.
5468
* @param target the target object to merge into.
@@ -62,8 +76,12 @@ public static <T> T merge(Object source, Object target, Class<T> clazz, List<Str
6276
? REQUEST_FIELD_NAMES_PER_CLASS.computeIfAbsent(clazz, ModelOptionsUtils::getJsonPropertyValues)
6377
: acceptedFieldNames;
6478

65-
Map<String, Object> sourceMap = objectToMap(source);
66-
Map<String, Object> targetMap = objectToMap(target);
79+
if (CollectionUtils.isEmpty(requestFieldNames)) {
80+
throw new IllegalArgumentException("No @JsonProperty fields found in the " + clazz.getName());
81+
}
82+
83+
Map<String, Object> sourceMap = ModelOptionsUtils.objectToMap(source);
84+
Map<String, Object> targetMap = ModelOptionsUtils.objectToMap(target);
6785

6886
targetMap.putAll(sourceMap.entrySet()
6987
.stream()
@@ -77,22 +95,23 @@ public static <T> T merge(Object source, Object target, Class<T> clazz, List<Str
7795
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
7896
}
7997

80-
return mapToClass(targetMap, clazz);
98+
return ModelOptionsUtils.mapToClass(targetMap, clazz);
8199
}
82100

83-
private static ConcurrentHashMap<Class<?>, List<String>> REQUEST_FIELD_NAMES_PER_CLASS = new ConcurrentHashMap<Class<?>, List<String>>();
84-
85101
/**
86102
* Merges the source object into the target object and returns an object represented
87-
* by the given class. The source null values are ignored.
103+
* by the given class. The JSON property names are used to match the fields to merge.
104+
* The source non-null values override the target values with the same field name. The
105+
* source null values are ignored. Returns the only field names that match the
106+
* {@code @JsonProperty} names, inferred from the provided clazz.
88107
* @param <T> they type of the class to return.
89108
* @param source the source object to merge.
90109
* @param target the target object to merge into.
91110
* @param clazz the class to return.
92111
* @return the merged object represented by the given class.
93112
*/
94113
public static <T> T merge(Object source, Object target, Class<T> clazz) {
95-
return merge(source, target, clazz, null);
114+
return ModelOptionsUtils.merge(source, target, clazz, null);
96115
}
97116

98117
/**
@@ -132,7 +151,7 @@ public static <T> T mapToClass(Map<String, Object> source, Class<T> clazz) {
132151
}
133152

134153
/**
135-
* Returns the list of values of the {@link JsonProperty} annotations.
154+
* Returns the list of name values of the {@link JsonProperty} annotations.
136155
* @param clazz the class that contains fields annotated with {@link JsonProperty}.
137156
* @return the list of values of the {@link JsonProperty} annotations.
138157
*/
@@ -148,4 +167,95 @@ public static List<String> getJsonPropertyValues(Class<?> clazz) {
148167
return values;
149168
}
150169

170+
/**
171+
* Returns a new instance of the targetBeanClazz that copies the bean values from the
172+
* sourceBean instance.
173+
* @param sourceBean the source bean to copy the values from.
174+
* @param sourceInterfaceClazz the source interface class. Only the fields with the
175+
* same name as the interface methods are copied. This allow the source object to be a
176+
* subclass of the source interface with additional, non-interface fields.
177+
* @param targetBeanClazz the target class, a subclass of the ChatOptions, to convert
178+
* into.
179+
* @param <T> the target class type.
180+
* @return a new instance of the targetBeanClazz with the values from the sourceBean
181+
* instance.
182+
*/
183+
public static <I, S extends I, T extends S> T copyToTarget(S sourceBean, Class<I> sourceInterfaceClazz,
184+
Class<T> targetBeanClazz) {
185+
186+
Assert.notNull(sourceInterfaceClazz, "SourceOptionsClazz must not be null");
187+
Assert.notNull(targetBeanClazz, "TargetOptionsClazz must not be null");
188+
189+
if (sourceBean == null) {
190+
return null;
191+
}
192+
193+
if (sourceBean.getClass().isAssignableFrom(targetBeanClazz)) {
194+
return (T) sourceBean;
195+
}
196+
197+
try {
198+
T targetOptions = targetBeanClazz.getConstructor().newInstance();
199+
200+
ModelOptionsUtils.mergeBeans(sourceBean, targetOptions, sourceInterfaceClazz, true);
201+
202+
return targetOptions;
203+
}
204+
catch (Exception e) {
205+
throw new RuntimeException(
206+
"Failed to convert the " + sourceInterfaceClazz.getName() + " into " + targetBeanClazz.getName(),
207+
e);
208+
}
209+
}
210+
211+
/**
212+
* Merges the source object into the target object. The source null values are
213+
* ignored. Only objects with Getter and Setter methods are supported.
214+
* @param <T> the type of the source and target object.
215+
* @param source the source object to merge.
216+
* @param target the target object to merge into.
217+
* @param sourceInterfaceClazz the source interface class. Only the fields with the
218+
* same name as the interface methods are merged. This allow the source object to be a
219+
* subclass of the source interface with additional, non-interface fields.
220+
* @param overrideNonNullTargetValues if true, the source non-null values override the
221+
* target values with the same field name. If false, the source non-null values are
222+
* ignored.
223+
* @return the merged target object.
224+
*/
225+
public static <I, S extends I, T extends S> T mergeBeans(S source, T target, Class<I> sourceInterfaceClazz,
226+
boolean overrideNonNullTargetValues) {
227+
Assert.notNull(source, "Source object must not be null");
228+
Assert.notNull(target, "Target object must not be null");
229+
230+
BeanWrapper sourceBeanWrap = new BeanWrapperImpl(source);
231+
BeanWrapper targetBeanWrap = new BeanWrapperImpl(target);
232+
233+
List<String> interfaceNames = Arrays.stream(sourceInterfaceClazz.getMethods()).map(m -> m.getName()).toList();
234+
235+
for (PropertyDescriptor descriptor : sourceBeanWrap.getPropertyDescriptors()) {
236+
237+
if (!BEAN_MERGE_FIELD_EXCISIONS.contains(descriptor.getName())
238+
&& interfaceNames.contains(toGetName(descriptor.getName()))) {
239+
240+
String propertyName = descriptor.getName();
241+
Object value = sourceBeanWrap.getPropertyValue(propertyName);
242+
243+
// Copy value to the target object
244+
if (value != null) {
245+
var targetValue = targetBeanWrap.getPropertyValue(propertyName);
246+
247+
if (targetValue == null || overrideNonNullTargetValues) {
248+
targetBeanWrap.setPropertyValue(propertyName, value);
249+
}
250+
}
251+
}
252+
}
253+
254+
return target;
255+
}
256+
257+
private static String toGetName(String name) {
258+
return "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
259+
}
260+
151261
}

0 commit comments

Comments
 (0)