Skip to content

Commit ae6a019

Browse files
PabloSanchitzolov
authored andcommitted
Add additional properties attribute to watsonx ai option and tests
- fix: remove non-sense underscore checking - fix: check the model is included in request parameters - fix: refactor, use constant - fix: add jsonproperty decorator to additional attribute - feat: allow additional params merging into the default watsonx option - fix: remove not needed key - fix: remove default values in watsonx ai options class - fix: add default property
1 parent 50f549d commit ae6a019

File tree

5 files changed

+122
-7
lines changed

5 files changed

+122
-7
lines changed

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatClient.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public WatsonxAiChatClient(WatsonxAiApi watsonxAiApi) {
6464
.withMaxNewTokens(20)
6565
.withMinNewTokens(0)
6666
.withRepetitionPenalty(1.0f)
67+
.withStopSequences(List.of())
6768
.build());
6869
}
6970

@@ -114,7 +115,10 @@ public WatsonxAiRequest request(Prompt prompt) {
114115
}
115116

116117
if (prompt.getOptions() != null) {
117-
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
118+
if (prompt.getOptions() instanceof WatsonxAiChatOptions runtimeOptions) {
119+
options = ModelOptionsUtils.merge(runtimeOptions, options, WatsonxAiChatOptions.class);
120+
}
121+
else if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
118122
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class,
119123
WatsonxAiChatOptions.class);
120124

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,18 @@
1515
*/
1616
package org.springframework.ai.watsonx;
1717

18+
import java.util.HashMap;
1819
import java.util.List;
1920
import java.util.Map;
2021
import java.util.stream.Collectors;
2122

2223
import com.fasterxml.jackson.annotation.JsonProperty;
24+
import com.fasterxml.jackson.annotation.JsonIgnore;
25+
import com.fasterxml.jackson.annotation.JsonAnyGetter;
26+
import com.fasterxml.jackson.annotation.JsonAnySetter;
2327
import com.fasterxml.jackson.core.JsonProcessingException;
2428
import com.fasterxml.jackson.core.type.TypeReference;
2529
import com.fasterxml.jackson.databind.ObjectMapper;
26-
2730
import org.springframework.ai.chat.prompt.ChatOptions;
2831

2932
/**
@@ -37,6 +40,7 @@
3740
* valid Parameters and values</a>
3841
*/
3942
// @formatter:off
43+
4044
public class WatsonxAiChatOptions implements ChatOptions {
4145

4246
/**
@@ -85,14 +89,14 @@ public class WatsonxAiChatOptions implements ChatOptions {
8589
/**
8690
* Sets how many tokens must the LLM generate. (Default: 0)
8791
*/
88-
@JsonProperty("min_new_tokens") private Integer minNewTokens = 0;
92+
@JsonProperty("min_new_tokens") private Integer minNewTokens;
8993

9094
/**
9195
* Sets when the LLM should stop.
9296
* (e.g., ["\n\n\n"]) then when the LLM generates three consecutive line breaks it will terminate.
9397
* Stop sequences are ignored until after the number of tokens that are specified in the Min tokens parameter are generated.
9498
*/
95-
@JsonProperty("stop_sequences") private List<String> stopSequences = List.of();
99+
@JsonProperty("stop_sequences") private List<String> stopSequences;
96100

97101
/**
98102
* Sets how strongly to penalize repetitions. A higher value
@@ -111,6 +115,14 @@ public class WatsonxAiChatOptions implements ChatOptions {
111115
*/
112116
@JsonProperty("model") private String model;
113117

118+
/**
119+
* Set additional request params (some model have non-predefined options)
120+
*/
121+
@JsonProperty("additional")
122+
private Map<String, Object> additional = new HashMap<>();
123+
124+
@JsonIgnore
125+
private ObjectMapper mapper = new ObjectMapper();
114126

115127
public Float getTemperature() {
116128
return temperature;
@@ -192,6 +204,20 @@ public void setModel(String model) {
192204
this.model = model;
193205
}
194206

207+
@JsonAnyGetter
208+
public Map<String, Object> getAdditionalProperties() {
209+
return additional.entrySet().stream()
210+
.collect(Collectors.toMap(
211+
entry -> toSnakeCase(entry.getKey()),
212+
Map.Entry::getValue
213+
));
214+
}
215+
216+
@JsonAnySetter
217+
public void addAdditionalProperty(String key, Object value) {
218+
additional.put(key, value);
219+
}
220+
195221
public static Builder builder() {
196222
return new Builder();
197223
}
@@ -250,6 +276,16 @@ public Builder withModel(String model) {
250276
return this;
251277
}
252278

279+
public Builder withAdditionalProperty(String key, Object value) {
280+
this.options.additional.put(key, value);
281+
return this;
282+
}
283+
284+
public Builder withAdditionalProperties(Map<String, Object> properties) {
285+
this.options.additional.putAll(properties);
286+
return this;
287+
}
288+
253289
public WatsonxAiChatOptions build() {
254290
return this.options;
255291
}
@@ -261,9 +297,11 @@ public WatsonxAiChatOptions build() {
261297
*/
262298
public Map<String, Object> toMap() {
263299
try {
264-
var json = new ObjectMapper().writeValueAsString(this);
265-
return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {
266-
});
300+
var json = mapper.writeValueAsString(this);
301+
var map = mapper.readValue(json, new TypeReference<Map<String, Object>>() {});
302+
map.remove("additional");
303+
304+
return map;
267305
}
268306
catch (JsonProcessingException e) {
269307
throw new RuntimeException(e);
@@ -282,5 +320,9 @@ public static Map<String, Object> filterNonSupportedFields(Map<String, Object> o
282320
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
283321
}
284322

323+
private String toSnakeCase(String input) {
324+
return input != null ? input.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase() : null;
325+
}
326+
285327
}
286328
// @formatter:on

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.fasterxml.jackson.annotation.JsonProperty;
2222

2323
import org.springframework.ai.watsonx.WatsonxAiChatOptions;
24+
import org.springframework.util.Assert;
2425

2526
// @formatter:off
2627
@JsonInclude(JsonInclude.Include.NON_NULL)
@@ -62,6 +63,7 @@ public WatsonxAiRequest withProjectId(String projectId) {
6263
public static Builder builder(String input) { return new Builder(input); }
6364

6465
public static class Builder {
66+
public static final String MODEL_PARAMETER_IS_REQUIRED = "Model parameter is required";
6567
private final String input;
6668
private Map<String, Object> parameters;
6769
private String model = "";
@@ -71,6 +73,7 @@ public Builder(String input) {
7173
}
7274

7375
public Builder withParameters(Map<String, Object> parameters) {
76+
Assert.notNull(parameters.get("model"), MODEL_PARAMETER_IS_REQUIRED);
7477
this.model = parameters.get("model").toString();
7578
this.parameters = WatsonxAiChatOptions.filterNonSupportedFields(parameters);
7679
return this;

models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/api/WatsonxAiChatOptionTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.ai.watsonx.WatsonxAiChatOptions;
2323

2424
import java.util.List;
25+
import java.util.Map;
2526

2627
/**
2728
* @author Pablo Sanchidrian Herrera
@@ -56,6 +57,68 @@ public void testOptions() {
5657
assertThat(optionsMap).containsEntry("random_seed", 4);
5758
}
5859

60+
@Test
61+
public void testOptionsWithAdditionalParamsOneByOne() {
62+
WatsonxAiChatOptions options = WatsonxAiChatOptions.builder()
63+
.withDecodingMethod("sample")
64+
.withTemperature(1.2f)
65+
.withTopK(20)
66+
.withTopP(0.5f)
67+
.withMaxNewTokens(100)
68+
.withMinNewTokens(20)
69+
.withStopSequences(List.of("\n\n\n"))
70+
.withRepetitionPenalty(1.1f)
71+
.withRandomSeed(4)
72+
.withAdditionalProperty("HAP", true)
73+
.withAdditionalProperty("typicalP", 0.5f)
74+
.build();
75+
76+
var optionsMap = options.toMap();
77+
78+
assertThat(optionsMap).containsEntry("decoding_method", "sample");
79+
assertThat(optionsMap).containsEntry("temperature", 1.2);
80+
assertThat(optionsMap).containsEntry("top_k", 20);
81+
assertThat(optionsMap).containsEntry("top_p", 0.5);
82+
assertThat(optionsMap).containsEntry("max_new_tokens", 100);
83+
assertThat(optionsMap).containsEntry("min_new_tokens", 20);
84+
assertThat(optionsMap).containsEntry("stop_sequences", List.of("\n\n\n"));
85+
assertThat(optionsMap).containsEntry("repetition_penalty", 1.1);
86+
assertThat(optionsMap).containsEntry("random_seed", 4);
87+
assertThat(optionsMap).containsEntry("hap", true);
88+
assertThat(optionsMap).containsEntry("typical_p", 0.5);
89+
}
90+
91+
@Test
92+
public void testOptionsWithAdditionalParamsMap() {
93+
WatsonxAiChatOptions options = WatsonxAiChatOptions.builder()
94+
.withDecodingMethod("sample")
95+
.withTemperature(1.2f)
96+
.withTopK(20)
97+
.withTopP(0.5f)
98+
.withMaxNewTokens(100)
99+
.withMinNewTokens(20)
100+
.withStopSequences(List.of("\n\n\n"))
101+
.withRepetitionPenalty(1.1f)
102+
.withRandomSeed(4)
103+
.withAdditionalProperties(Map.of("HAP", true, "typicalP", 0.5f, "test_value", "test"))
104+
.build();
105+
106+
var optionsMap = options.toMap();
107+
108+
assertThat(optionsMap).containsEntry("decoding_method", "sample");
109+
assertThat(optionsMap).containsEntry("temperature", 1.2);
110+
assertThat(optionsMap).containsEntry("top_k", 20);
111+
assertThat(optionsMap).containsEntry("top_p", 0.5);
112+
assertThat(optionsMap).containsEntry("max_new_tokens", 100);
113+
assertThat(optionsMap).containsEntry("min_new_tokens", 20);
114+
assertThat(optionsMap).containsEntry("stop_sequences", List.of("\n\n\n"));
115+
assertThat(optionsMap).containsEntry("repetition_penalty", 1.1);
116+
assertThat(optionsMap).containsEntry("random_seed", 4);
117+
assertThat(optionsMap).containsEntry("hap", true);
118+
assertThat(optionsMap).containsEntry("typical_p", 0.5);
119+
assertThat(optionsMap).containsEntry("test_value", "test");
120+
}
121+
59122
@Test
60123
public void testFilterOut() {
61124
WatsonxAiChatOptions options = WatsonxAiChatOptions.builder().withModel("google/flan-ul2").build();

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/watsonxai/WatsonxAiChatProperties.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import org.springframework.boot.context.properties.ConfigurationProperties;
2020
import org.springframework.boot.context.properties.NestedConfigurationProperty;
2121

22+
import java.util.List;
23+
2224
/**
2325
* Chat properties for Watsonx.AI Chat.
2426
*
@@ -48,6 +50,7 @@ public class WatsonxAiChatProperties {
4850
.withMaxNewTokens(20)
4951
.withMinNewTokens(0)
5052
.withRepetitionPenalty(1.0f)
53+
.withStopSequences(List.of())
5154
.build();
5255

5356
public boolean isEnabled() {

0 commit comments

Comments
 (0)