Skip to content

Commit 035036c

Browse files
PabloSanchi authored and markpollack committed
Add support for watsonx.ai embedding model
This commit introduces support for the Watsonx.ai embedding model. It includes:
- Watsonx embedding options class with tests
- Watsonx embedding model implementation
- Auto-configuration and properties for the embedding model
- Tests for the Watsonx embedding model
- Documentation for using the Watsonx embedding model

Also removed the use of deprecated APIs in WatsonxAiChatModel.
1 parent 49b5ff5 commit 035036c

File tree

21 files changed

+676
-71
lines changed

21 files changed

+676
-71
lines changed

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatModel.java

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Map;
2020

21+
import org.springframework.ai.chat.messages.AssistantMessage;
2122
import org.springframework.ai.chat.model.ChatModel;
2223
import reactor.core.publisher.Flux;
2324

@@ -29,17 +30,18 @@
2930
import org.springframework.ai.chat.prompt.Prompt;
3031
import org.springframework.ai.model.ModelOptionsUtils;
3132
import org.springframework.ai.watsonx.api.WatsonxAiApi;
32-
import org.springframework.ai.watsonx.api.WatsonxAiRequest;
33-
import org.springframework.ai.watsonx.api.WatsonxAiResponse;
33+
import org.springframework.ai.watsonx.api.WatsonxAiChatRequest;
34+
import org.springframework.ai.watsonx.api.WatsonxAiChatResponse;
3435
import org.springframework.ai.watsonx.utils.MessageToPromptConverter;
3536
import org.springframework.util.Assert;
3637

3738
/**
3839
* {@link ChatModel} implementation for {@literal watsonx.ai}.
39-
*
40+
* <p>
4041
* watsonx.ai allows developers to use large language models within a SaaS service. It
41-
* supports multiple open-source models as well as IBM created models
42-
* [watsonx.ai](https://www.ibm.com/products/watsonx-ai). Please refer to the <a href=
42+
* supports multiple open-source models as well as IBM created models.
43+
* <p>
44+
* Please refer to the <a href=
4345
* "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx">watsonx.ai
4446
* models</a> for the most up-to-date information about the available models.
4547
*
@@ -78,35 +80,37 @@ public WatsonxAiChatModel(WatsonxAiApi watsonxAiApi, WatsonxAiChatOptions defaul
7880
@Override
7981
public ChatResponse call(Prompt prompt) {
8082

81-
WatsonxAiRequest request = request(prompt);
82-
83-
WatsonxAiResponse response = this.watsonxAiApi.generate(request).getBody();
84-
var generator = new Generation(response.results().get(0).generatedText());
83+
WatsonxAiChatRequest request = request(prompt);
8584

86-
generator = generator.withGenerationMetadata(
85+
WatsonxAiChatResponse response = this.watsonxAiApi.generate(request).getBody();
86+
var generation = new Generation(new AssistantMessage(response.results().get(0).generatedText()),
8787
ChatGenerationMetadata.from(response.results().get(0).stopReason(), response.system()));
8888

89-
return new ChatResponse(List.of(generator));
89+
return new ChatResponse(List.of(generation));
9090
}
9191

9292
@Override
9393
public Flux<ChatResponse> stream(Prompt prompt) {
9494

95-
WatsonxAiRequest request = request(prompt);
95+
WatsonxAiChatRequest request = request(prompt);
9696

97-
Flux<WatsonxAiResponse> response = this.watsonxAiApi.generateStreaming(request);
97+
Flux<WatsonxAiChatResponse> response = this.watsonxAiApi.generateStreaming(request);
9898

9999
return response.map(chunk -> {
100-
Generation generation = new Generation(chunk.results().get(0).generatedText());
100+
String generatedText = chunk.results().get(0).generatedText();
101+
AssistantMessage assistantMessage = new AssistantMessage(generatedText);
102+
103+
ChatGenerationMetadata metadata = ChatGenerationMetadata.NULL;
101104
if (chunk.system() != null) {
102-
generation = generation.withGenerationMetadata(
103-
ChatGenerationMetadata.from(chunk.results().get(0).stopReason(), chunk.system()));
105+
metadata = ChatGenerationMetadata.from(chunk.results().get(0).stopReason(), chunk.system());
104106
}
107+
108+
Generation generation = new Generation(assistantMessage, metadata);
105109
return new ChatResponse(List.of(generation));
106110
});
107111
}
108112

109-
public WatsonxAiRequest request(Prompt prompt) {
113+
public WatsonxAiChatRequest request(Prompt prompt) {
110114

111115
WatsonxAiChatOptions options = WatsonxAiChatOptions.builder().build();
112116

@@ -133,7 +137,7 @@ public WatsonxAiRequest request(Prompt prompt) {
133137
.withHumanPrompt("")
134138
.toPrompt(prompt.getInstructions());
135139

136-
return WatsonxAiRequest.builder(convertedPrompt).withParameters(parameters).build();
140+
return WatsonxAiChatRequest.builder(convertedPrompt).withParameters(parameters).build();
137141
}
138142

139143
@Override

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public class WatsonxAiChatOptions implements ChatOptions {
123123
private Map<String, Object> additional = new HashMap<>();
124124

125125
@JsonIgnore
126-
private ObjectMapper mapper = new ObjectMapper();
126+
private final ObjectMapper mapper = new ObjectMapper();
127127

128128
@Override
129129
public Double getTemperature() {
@@ -343,7 +343,7 @@ public Map<String, Object> toMap() {
343343
}
344344

345345
/**
346-
* Filter out the non supported fields from the options.
346+
* Filter out the non-supported fields from the options.
347347
* @param options The options to filter.
348348
* @return The filtered options.
349349
*/
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package org.springframework.ai.watsonx;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.document.Document;
6+
import org.springframework.ai.embedding.*;
7+
import org.springframework.ai.watsonx.api.WatsonxAiApi;
8+
import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingRequest;
9+
import org.springframework.ai.watsonx.api.WatsonxAiEmbeddingResponse;
10+
import org.springframework.util.Assert;
11+
import org.springframework.util.StringUtils;
12+
13+
import java.util.List;
14+
import java.util.concurrent.atomic.AtomicInteger;
15+
16+
/**
17+
* {@link EmbeddingModel} implementation for {@literal Watsonx.ai}.
18+
* <p>
19+
* Watsonx.ai allows developers to run large language models and generate embeddings. It
20+
* supports open-source models available on <a href=
21+
* "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx">Watsonx.ai
22+
* models</a>.
23+
* <p>
24+
* Please refer to the <a href="https://www.ibm.com/products/watsonx-ai/">official
25+
* Watsonx.ai website</a> for the most up-to-date information on available models.
26+
*
27+
* @author Pablo Sanchidrian Herrera
28+
* @since 1.0.0
29+
*/
30+
public class WatsonxAiEmbeddingModel extends AbstractEmbeddingModel {
31+
32+
private final Logger logger = LoggerFactory.getLogger(getClass());
33+
34+
private final WatsonxAiApi watsonxAiApi;
35+
36+
/**
37+
* Default options to be used for all embedding requests.
38+
*/
39+
private WatsonxAiEmbeddingOptions defaultOptions = WatsonxAiEmbeddingOptions.create()
40+
.withModel(WatsonxAiEmbeddingOptions.DEFAULT_MODEL);
41+
42+
public WatsonxAiEmbeddingModel(WatsonxAiApi watsonxAiApi) {
43+
this.watsonxAiApi = watsonxAiApi;
44+
}
45+
46+
public WatsonxAiEmbeddingModel(WatsonxAiApi watsonxAiApi, WatsonxAiEmbeddingOptions defaultOptions) {
47+
this.watsonxAiApi = watsonxAiApi;
48+
this.defaultOptions = defaultOptions;
49+
}
50+
51+
@Override
52+
public float[] embed(Document document) {
53+
return embed(document.getContent());
54+
}
55+
56+
@Override
57+
public EmbeddingResponse call(EmbeddingRequest request) {
58+
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
59+
60+
WatsonxAiEmbeddingRequest embeddingRequest = watsonxAiEmbeddingRequest(request.getInstructions(),
61+
request.getOptions());
62+
WatsonxAiEmbeddingResponse response = this.watsonxAiApi.embeddings(embeddingRequest).getBody();
63+
64+
AtomicInteger indexCounter = new AtomicInteger(0);
65+
List<Embedding> embeddings = response.results()
66+
.stream()
67+
.map(e -> new Embedding(e.embedding(), indexCounter.getAndIncrement()))
68+
.toList();
69+
70+
return new EmbeddingResponse(embeddings);
71+
}
72+
73+
WatsonxAiEmbeddingRequest watsonxAiEmbeddingRequest(List<String> inputs, EmbeddingOptions options) {
74+
75+
WatsonxAiEmbeddingOptions runtimeOptions = (options instanceof WatsonxAiEmbeddingOptions)
76+
? (WatsonxAiEmbeddingOptions) options : this.defaultOptions;
77+
78+
if (!StringUtils.hasText(runtimeOptions.getModel())) {
79+
this.logger.warn("The model cannot be null, using default model instead");
80+
runtimeOptions = this.defaultOptions;
81+
}
82+
83+
return WatsonxAiEmbeddingRequest.builder(inputs).withModel(runtimeOptions.getModel()).build();
84+
}
85+
86+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package org.springframework.ai.watsonx;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnore;
4+
import com.fasterxml.jackson.annotation.JsonInclude;
5+
import com.fasterxml.jackson.annotation.JsonProperty;
6+
import org.springframework.ai.embedding.EmbeddingOptions;
7+
8+
/**
9+
* The configuration information for the embedding requests.
10+
*
11+
* @author Pablo Sanchidrian Herrera
12+
* @since 1.0.0
13+
*/
14+
@JsonInclude(JsonInclude.Include.NON_NULL)
15+
public class WatsonxAiEmbeddingOptions implements EmbeddingOptions {
16+
17+
public static final String DEFAULT_MODEL = "ibm/slate-30m-english-rtrvr";
18+
19+
/**
20+
* The embedding model identifier
21+
*/
22+
@JsonProperty("model_id")
23+
private String model;
24+
25+
public WatsonxAiEmbeddingOptions withModel(String model) {
26+
this.model = model;
27+
return this;
28+
}
29+
30+
public String getModel() {
31+
return model;
32+
}
33+
34+
public void setModel(String model) {
35+
this.model = model;
36+
}
37+
38+
@Override
39+
@JsonIgnore
40+
public Integer getDimensions() {
41+
return null;
42+
}
43+
44+
/**
45+
* Helper factory method to create a new {@link WatsonxAiEmbeddingOptions} instance.
46+
* @return A new {@link WatsonxAiEmbeddingOptions} instance.
47+
*/
48+
public static WatsonxAiEmbeddingOptions create() {
49+
return new WatsonxAiEmbeddingOptions();
50+
}
51+
52+
public static WatsonxAiEmbeddingOptions fromOptions(WatsonxAiEmbeddingOptions fromOptions) {
53+
return new WatsonxAiEmbeddingOptions().withModel(fromOptions.getModel());
54+
}
55+
56+
}

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiApi.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public class WatsonxAiApi {
5252
private final IamAuthenticator iamAuthenticator;
5353
private final String streamEndpoint;
5454
private final String textEndpoint;
55+
private final String embeddingEndpoint;
5556
private final String projectId;
5657
private IamToken token;
5758

@@ -60,6 +61,7 @@ public class WatsonxAiApi {
6061
* @param baseUrl api base URL.
6162
* @param streamEndpoint streaming generation.
6263
* @param textEndpoint text generation.
64+
* @param embeddingEndpoint embedding generation.
6365
* @param projectId watsonx.ai project identifier.
6466
* @param IAMToken IBM Cloud IAM token.
6567
* @param restClientBuilder rest client builder.
@@ -68,12 +70,14 @@ public WatsonxAiApi(
6870
String baseUrl,
6971
String streamEndpoint,
7072
String textEndpoint,
73+
String embeddingEndpoint,
7174
String projectId,
7275
String IAMToken,
7376
RestClient.Builder restClientBuilder
7477
) {
7578
this.streamEndpoint = streamEndpoint;
7679
this.textEndpoint = textEndpoint;
80+
this.embeddingEndpoint = embeddingEndpoint;
7781
this.projectId = projectId;
7882
this.iamAuthenticator = IamAuthenticator.fromConfiguration(Map.of("APIKEY", IAMToken));
7983
this.token = this.iamAuthenticator.requestToken();
@@ -94,8 +98,8 @@ public WatsonxAiApi(
9498
}
9599

96100
@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
97-
public ResponseEntity<WatsonxAiResponse> generate(WatsonxAiRequest watsonxAiRequest) {
98-
Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL);
101+
public ResponseEntity<WatsonxAiChatResponse> generate(WatsonxAiChatRequest watsonxAiChatRequest) {
102+
Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL);
99103

100104
if(this.token.needsRefresh()) {
101105
this.token = this.iamAuthenticator.requestToken();
@@ -104,14 +108,14 @@ public ResponseEntity<WatsonxAiResponse> generate(WatsonxAiRequest watsonxAiRequ
104108
return this.restClient.post()
105109
.uri(this.textEndpoint)
106110
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
107-
.body(watsonxAiRequest.withProjectId(projectId))
111+
.body(watsonxAiChatRequest.withProjectId(projectId))
108112
.retrieve()
109-
.toEntity(WatsonxAiResponse.class);
113+
.toEntity(WatsonxAiChatResponse.class);
110114
}
111115

112116
@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
113-
public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiRequest) {
114-
Assert.notNull(watsonxAiRequest, WATSONX_REQUEST_CANNOT_BE_NULL);
117+
public Flux<WatsonxAiChatResponse> generateStreaming(WatsonxAiChatRequest watsonxAiChatRequest) {
118+
Assert.notNull(watsonxAiChatRequest, WATSONX_REQUEST_CANNOT_BE_NULL);
115119

116120
if(this.token.needsRefresh()) {
117121
this.token = this.iamAuthenticator.requestToken();
@@ -120,9 +124,9 @@ public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiReque
120124
return this.webClient.post()
121125
.uri(this.streamEndpoint)
122126
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
123-
.bodyValue(watsonxAiRequest.withProjectId(this.projectId))
127+
.bodyValue(watsonxAiChatRequest.withProjectId(this.projectId))
124128
.retrieve()
125-
.bodyToFlux(WatsonxAiResponse.class)
129+
.bodyToFlux(WatsonxAiChatResponse.class)
126130
.handle((data, sink) -> {
127131
if (logger.isTraceEnabled()) {
128132
logger.trace(data);
@@ -131,4 +135,21 @@ public Flux<WatsonxAiResponse> generateStreaming(WatsonxAiRequest watsonxAiReque
131135
});
132136
}
133137

138+
@Retryable(retryFor = Exception.class, maxAttempts = 3, backoff = @Backoff(random = true, delay = 1200, maxDelay = 7000, multiplier = 2.5))
139+
public ResponseEntity<WatsonxAiEmbeddingResponse> embeddings(WatsonxAiEmbeddingRequest request) {
140+
Assert.notNull(request, WATSONX_REQUEST_CANNOT_BE_NULL);
141+
142+
if(this.token.needsRefresh()) {
143+
this.token = this.iamAuthenticator.requestToken();
144+
}
145+
146+
return this.restClient.post()
147+
.uri(this.embeddingEndpoint)
148+
.header(HttpHeaders.AUTHORIZATION, "Bearer " + this.token.getAccessToken())
149+
.body(request.withProjectId(projectId))
150+
.retrieve()
151+
.toEntity(WatsonxAiEmbeddingResponse.class);
152+
}
153+
154+
134155
}

models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiRequest.java renamed to models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/api/WatsonxAiChatRequest.java

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,15 @@
2323
import org.springframework.ai.watsonx.WatsonxAiChatOptions;
2424
import org.springframework.util.Assert;
2525

26+
/**
27+
* Java class for Watsonx.ai Chat Request object.
28+
*
29+
* @author Pablo Sanchidrian Herrera
30+
* @since 1.0.0
31+
*/
2632
// @formatter:off
2733
@JsonInclude(JsonInclude.Include.NON_NULL)
28-
public class WatsonxAiRequest {
34+
public class WatsonxAiChatRequest {
2935

3036
@JsonProperty("input")
3137
private String input;
@@ -36,19 +42,14 @@ public class WatsonxAiRequest {
3642
@JsonProperty("project_id")
3743
private String projectId = "";
3844

39-
private WatsonxAiRequest(String input, Map<String, Object> parameters, String modelId, String projectId) {
45+
private WatsonxAiChatRequest(String input, Map<String, Object> parameters, String modelId, String projectId) {
4046
this.input = input;
4147
this.parameters = parameters;
4248
this.modelId = modelId;
4349
this.projectId = projectId;
4450
}
4551

46-
public WatsonxAiRequest withModelId(String modelId) {
47-
this.modelId = modelId;
48-
return this;
49-
}
50-
51-
public WatsonxAiRequest withProjectId(String projectId) {
52+
public WatsonxAiChatRequest withProjectId(String projectId) {
5253
this.projectId = projectId;
5354
return this;
5455
}
@@ -79,8 +80,8 @@ public Builder withParameters(Map<String, Object> parameters) {
7980
return this;
8081
}
8182

82-
public WatsonxAiRequest build() {
83-
return new WatsonxAiRequest(input, parameters, model, "");
83+
public WatsonxAiChatRequest build() {
84+
return new WatsonxAiChatRequest(input, parameters, model, "");
8485
}
8586

8687
}

0 commit comments

Comments
 (0)