Skip to content

Commit 7ea867d

Browse files
committed
Add VertexAi Chat Options Support + Docs
1 parent d38afc5 commit 7ea867d

File tree

12 files changed

+516
-119
lines changed

12 files changed

+516
-119
lines changed

models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/VertexAiChatClient.java

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
import java.util.stream.Collectors;
2121

2222
import org.springframework.ai.chat.ChatClient;
23+
import org.springframework.ai.chat.ChatOptions;
2324
import org.springframework.ai.chat.ChatResponse;
2425
import org.springframework.ai.chat.Generation;
2526
import org.springframework.ai.chat.prompt.Prompt;
27+
import org.springframework.ai.model.ModelOptionsUtils;
2628
import org.springframework.ai.chat.messages.MessageType;
2729
import org.springframework.ai.vertex.api.VertexAiApi;
2830
import org.springframework.ai.vertex.api.VertexAiApi.GenerateMessageRequest;
@@ -38,40 +40,40 @@ public class VertexAiChatClient implements ChatClient {
3840

3941
private final VertexAiApi vertexAiApi;
4042

41-
private Float temperature;
43+
private final VertexAiChatOptions defaultOptions;
4244

43-
private Float topP;
44-
45-
private Integer topK;
45+
public VertexAiChatClient(VertexAiApi vertexAiApi) {
46+
this(vertexAiApi,
47+
VertexAiChatOptions.builder().withTemperature(0.7f).withCandidateCount(1).withTopK(20).build());
48+
}
4649

47-
private Integer candidateCount;
50+
public VertexAiChatClient(VertexAiApi vertexAiApi, VertexAiChatOptions defaultOptions) {
51+
Assert.notNull(defaultOptions, "Default options must not be null!");
52+
Assert.notNull(vertexAiApi, "VertexAiApi must not be null!");
4853

49-
public VertexAiChatClient(VertexAiApi vertexAiApi) {
5054
this.vertexAiApi = vertexAiApi;
55+
this.defaultOptions = defaultOptions;
5156
}
5257

53-
public VertexAiChatClient withTemperature(Float temperature) {
54-
this.temperature = temperature;
55-
return this;
56-
}
58+
@Override
59+
public ChatResponse call(Prompt prompt) {
5760

58-
public VertexAiChatClient withTopP(Float topP) {
59-
this.topP = topP;
60-
return this;
61-
}
61+
GenerateMessageRequest request = createRequest(prompt);
6262

63-
public VertexAiChatClient withTopK(Integer topK) {
64-
this.topK = topK;
65-
return this;
66-
}
63+
GenerateMessageResponse response = this.vertexAiApi.generateMessage(request);
6764

68-
public VertexAiChatClient withCandidateCount(Integer maxTokens) {
69-
this.candidateCount = maxTokens;
70-
return this;
65+
List<Generation> generations = response.candidates()
66+
.stream()
67+
.map(vmsg -> new Generation(vmsg.content()))
68+
.toList();
69+
70+
return new ChatResponse(generations);
7171
}
7272

73-
@Override
74-
public ChatResponse call(Prompt prompt) {
73+
/**
74+
* Accessible for testing.
75+
*/
76+
GenerateMessageRequest createRequest(Prompt prompt) {
7577

7678
String vertexContext = prompt.getInstructions()
7779
.stream()
@@ -89,17 +91,25 @@ public ChatResponse call(Prompt prompt) {
8991

9092
var vertexPrompt = new MessagePrompt(vertexContext, vertexMessages);
9193

92-
GenerateMessageRequest request = new GenerateMessageRequest(vertexPrompt, this.temperature, this.candidateCount,
93-
this.topP, this.topK);
94-
95-
GenerateMessageResponse response = this.vertexAiApi.generateMessage(request);
96-
97-
List<Generation> generations = response.candidates()
98-
.stream()
99-
.map(vmsg -> new Generation(vmsg.content()))
100-
.toList();
101-
102-
return new ChatResponse(generations);
94+
GenerateMessageRequest request = new GenerateMessageRequest(vertexPrompt);
95+
96+
if (this.defaultOptions != null) {
97+
request = ModelOptionsUtils.merge(request, this.defaultOptions, GenerateMessageRequest.class);
98+
}
99+
100+
if (prompt.getOptions() != null) {
101+
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
102+
VertexAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
103+
ChatOptions.class, VertexAiChatOptions.class);
104+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, GenerateMessageRequest.class);
105+
}
106+
else {
107+
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
108+
+ prompt.getOptions().getClass().getSimpleName());
109+
}
110+
}
111+
112+
return request;
103113
}
104114

105115
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.vertex;
18+
19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
23+
import org.springframework.ai.chat.ChatOptions;
24+
25+
/**
26+
* @author Christian Tzolov
27+
*/
28+
@JsonInclude(Include.NON_NULL)
29+
public class VertexAiChatOptions implements ChatOptions {
30+
31+
// @formatter:off
32+
/**
33+
* Controls the randomness of the output. Values can range over [0.0,1.0], inclusive.
34+
* A value closer to 1.0 will produce responses that are more varied, while a value
35+
* closer to 0.0 will typically result in less surprising responses from the
36+
* generative model. This value specifies the default to be used by the backend while making the
37+
* call to the generative model.
38+
*/
39+
private @JsonProperty("temperature") Float temperature;
40+
41+
/**
42+
* The number of generated response messages to return. This value must be between [1,
43+
* 8], inclusive. Defaults to 1.
44+
*/
45+
private @JsonProperty("candidateCount") Integer candidateCount;
46+
47+
/**
48+
* The maximum cumulative probability of tokens to consider when sampling. The
49+
* generative model uses combined Top-k and nucleus sampling. Nucleus sampling considers the
50+
* smallest set of tokens whose probability sum is at least topP.
51+
*/
52+
private @JsonProperty("topP") Float topP;
53+
54+
/**
55+
* The maximum number of tokens to consider when sampling. The generative model uses
56+
* combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most
57+
* probable tokens.
58+
*/
59+
private @JsonProperty("topK") Integer topK;
60+
// @formatter:on
61+
62+
public static Builder builder() {
63+
return new Builder();
64+
}
65+
66+
public static class Builder {
67+
68+
private VertexAiChatOptions options = new VertexAiChatOptions();
69+
70+
public Builder withTemperature(Float temperature) {
71+
this.options.temperature = temperature;
72+
return this;
73+
}
74+
75+
public Builder withCandidateCount(Integer candidateCount) {
76+
this.options.candidateCount = candidateCount;
77+
return this;
78+
}
79+
80+
public Builder withTopP(Float topP) {
81+
this.options.topP = topP;
82+
return this;
83+
}
84+
85+
public Builder withTopK(Integer topK) {
86+
this.options.topK = topK;
87+
return this;
88+
}
89+
90+
public VertexAiChatOptions build() {
91+
return this.options;
92+
}
93+
94+
}
95+
96+
@Override
97+
public Float getTemperature() {
98+
return this.temperature;
99+
}
100+
101+
@Override
102+
public void setTemperature(Float temperature) {
103+
this.temperature = temperature;
104+
}
105+
106+
public Integer getCandidateCount() {
107+
return this.candidateCount;
108+
}
109+
110+
public void setCandidateCount(Integer candidateCount) {
111+
this.candidateCount = candidateCount;
112+
}
113+
114+
@Override
115+
public Float getTopP() {
116+
return this.topP;
117+
}
118+
119+
@Override
120+
public void setTopP(Float topP) {
121+
this.topP = topP;
122+
}
123+
124+
@Override
125+
public Integer getTopK() {
126+
return this.topK;
127+
}
128+
129+
@Override
130+
public void setTopK(Integer topK) {
131+
this.topK = topK;
132+
}
133+
134+
}

models/spring-ai-vertex-ai/src/main/java/org/springframework/ai/vertex/api/VertexAiApi.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public class VertexAiApi {
107107

108108
private final String apiKey;
109109

110-
private final String generateModel;
110+
private final String chatModel;
111111

112112
private final String embeddingModel;
113113

@@ -130,7 +130,7 @@ public VertexAiApi(String apiKey) {
130130
public VertexAiApi(String baseUrl, String apiKey, String model, String embeddingModel,
131131
RestClient.Builder restClientBuilder) {
132132

133-
this.generateModel = model;
133+
this.chatModel = model;
134134
this.embeddingModel = embeddingModel;
135135
this.apiKey = apiKey;
136136

@@ -165,11 +165,12 @@ public void handleError(ClientHttpResponse response) throws IOException {
165165
* @param request Request body.
166166
* @return Response body.
167167
*/
168+
@SuppressWarnings("null")
168169
public GenerateMessageResponse generateMessage(GenerateMessageRequest request) {
169170
Assert.notNull(request, "The request body can not be null.");
170171

171172
return this.restClient.post()
172-
.uri("/models/{model}:generateMessage?key={apiKey}", this.generateModel, this.apiKey)
173+
.uri("/models/{model}:generateMessage?key={apiKey}", this.chatModel, this.apiKey)
173174
.body(request)
174175
.retrieve()
175176
.body(GenerateMessageResponse.class);
@@ -231,7 +232,7 @@ record TokenCount(@JsonProperty("tokenCount") Integer tokenCount) {
231232
}
232233

233234
TokenCount tokenCountResponse = this.restClient.post()
234-
.uri("/models/{model}:countMessageTokens?key={apiKey}", this.generateModel, this.apiKey)
235+
.uri("/models/{model}:countMessageTokens?key={apiKey}", this.chatModel, this.apiKey)
235236
.body(Map.of("prompt", prompt))
236237
.retrieve()
237238
.body(TokenCount.class);
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.springframework.ai.vertex.generation;
1+
package org.springframework.ai.vertex;
22

33
import java.util.Arrays;
44
import java.util.List;
@@ -9,15 +9,14 @@
99

1010
import org.springframework.ai.chat.ChatResponse;
1111
import org.springframework.ai.chat.Generation;
12-
import org.springframework.ai.parser.BeanOutputParser;
13-
import org.springframework.ai.parser.ListOutputParser;
14-
import org.springframework.ai.parser.MapOutputParser;
12+
import org.springframework.ai.chat.messages.Message;
13+
import org.springframework.ai.chat.messages.UserMessage;
1514
import org.springframework.ai.chat.prompt.Prompt;
1615
import org.springframework.ai.chat.prompt.PromptTemplate;
1716
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
18-
import org.springframework.ai.chat.messages.Message;
19-
import org.springframework.ai.chat.messages.UserMessage;
20-
import org.springframework.ai.vertex.VertexAiChatClient;
17+
import org.springframework.ai.parser.BeanOutputParser;
18+
import org.springframework.ai.parser.ListOutputParser;
19+
import org.springframework.ai.parser.MapOutputParser;
2120
import org.springframework.ai.vertex.api.VertexAiApi;
2221
import org.springframework.beans.factory.annotation.Autowired;
2322
import org.springframework.beans.factory.annotation.Value;

0 commit comments

Comments
 (0)