Skip to content

Commit 5054718

Browse files
committed
Rename legacy Vertex AI into Vertx AI PaLM2 classes
1 parent aa55a96 commit 5054718

File tree

28 files changed

+263
-215
lines changed

28 files changed

+263
-215
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,13 @@ public enum ChatCompletionFinishReason {
590590
/**
591591
* The model called a tool.
592592
*/
593-
@JsonProperty("tool_call") TOOL_CALL
593+
@JsonProperty("tool_call") TOOL_CALL,
594+
595+
// anticipation of future changes. Based on:
596+
// https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py
597+
@JsonProperty("error") ERROR,
598+
599+
@JsonProperty("tool_calls") TOOL_CALLS
594600
// @formatter:on
595601

596602
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,30 @@ public String getValue() {
9999

100100
}
101101

102+
public enum ChatModel {
103+
104+
GEMINI_PRO_VISION("gemini-pro-vision"),
105+
106+
GEMINI_PRO("gemini-pro");
107+
108+
ChatModel(String value) {
109+
this.value = value;
110+
}
111+
112+
public final String value;
113+
114+
public String getValue() {
115+
return this.value;
116+
}
117+
118+
}
119+
102120
public VertexAiGeminiChatClient(VertexAI vertexAI) {
103121
this(vertexAI,
104-
VertexAiGeminiChatOptions.builder().withModel("gemini-pro-vision").withTemperature(0.8f).build());
122+
VertexAiGeminiChatOptions.builder()
123+
.withModel(ChatModel.GEMINI_PRO_VISION.getValue())
124+
.withTemperature(0.8f)
125+
.build());
105126
}
106127

107128
public VertexAiGeminiChatClient(VertexAI vertexAI, VertexAiGeminiChatOptions options) {

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClientIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ public VertexAI vertexAiApi() {
232232
public VertexAiGeminiChatClient vertexAiEmbedding(VertexAI vertexAi) {
233233
return new VertexAiGeminiChatClient(vertexAi,
234234
VertexAiGeminiChatOptions.builder()
235-
.withModel("gemini-pro-vision")
235+
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO_VISION.getValue())
236236
.withTransportType(TransportType.REST)
237237
.build());
238238
}

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/function/VertexAiGeminiChatClientFunctionCallingIT.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public void functionCallExplicitOpenApiSchema() {
9393
""";
9494

9595
var promptOptions = VertexAiGeminiChatOptions.builder()
96-
.withModel("gemini-pro")
96+
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
9797
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
9898
.withName("getCurrentWeather")
9999
.withDescription("Get the current weather in a given location")
@@ -122,7 +122,7 @@ public void functionCallTestInferredOpenApiSchema() {
122122
List<Message> messages = new ArrayList<>(List.of(userMessage));
123123

124124
var promptOptions = VertexAiGeminiChatOptions.builder()
125-
.withModel("gemini-pro")
125+
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
126126
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
127127
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
128128
.withName("getCurrentWeather")
@@ -152,7 +152,7 @@ public void functionCallTestInferredOpenApiSchemaStream() {
152152
List<Message> messages = new ArrayList<>(List.of(userMessage));
153153

154154
var promptOptions = VertexAiGeminiChatOptions.builder()
155-
.withModel("gemini-pro")
155+
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
156156
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
157157
.withSchemaType(SchemaType.OPEN_API_SCHEMA)
158158
.withName("getCurrentWeather")
@@ -193,7 +193,7 @@ public VertexAI vertexAiApi() {
193193
public VertexAiGeminiChatClient vertexAiEmbedding(VertexAI vertexAi) {
194194
return new VertexAiGeminiChatClient(vertexAi,
195195
VertexAiGeminiChatOptions.builder()
196-
.withModel("gemini-pro")
196+
.withModel(VertexAiGeminiChatClient.ChatModel.GEMINI_PRO.getValue())
197197
.withTemperature(0.9f)
198198
.withTransportType(TransportType.REST)
199199
.build());
Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,31 @@
2525
import org.springframework.ai.chat.Generation;
2626
import org.springframework.ai.chat.prompt.Prompt;
2727
import org.springframework.ai.model.ModelOptionsUtils;
28-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi;
29-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi.GenerateMessageRequest;
30-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi.GenerateMessageResponse;
31-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi.MessagePrompt;
28+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
29+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageRequest;
30+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.GenerateMessageResponse;
31+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api.MessagePrompt;
3232
import org.springframework.ai.chat.messages.MessageType;
3333
import org.springframework.util.Assert;
3434
import org.springframework.util.CollectionUtils;
3535

3636
/**
3737
* @author Christian Tzolov
3838
*/
39-
public class VertexAiChatClient implements ChatClient {
39+
public class VertexAiPaLm2ChatClient implements ChatClient {
4040

41-
private final VertexAiApi vertexAiApi;
41+
private final VertexAiPaLm2Api vertexAiApi;
4242

43-
private final VertexAiChatOptions defaultOptions;
43+
private final VertexAiPaLm2ChatOptions defaultOptions;
4444

45-
public VertexAiChatClient(VertexAiApi vertexAiApi) {
45+
public VertexAiPaLm2ChatClient(VertexAiPaLm2Api vertexAiApi) {
4646
this(vertexAiApi,
47-
VertexAiChatOptions.builder().withTemperature(0.7f).withCandidateCount(1).withTopK(20).build());
47+
VertexAiPaLm2ChatOptions.builder().withTemperature(0.7f).withCandidateCount(1).withTopK(20).build());
4848
}
4949

50-
public VertexAiChatClient(VertexAiApi vertexAiApi, VertexAiChatOptions defaultOptions) {
50+
public VertexAiPaLm2ChatClient(VertexAiPaLm2Api vertexAiApi, VertexAiPaLm2ChatOptions defaultOptions) {
5151
Assert.notNull(defaultOptions, "Default options must not be null!");
52-
Assert.notNull(vertexAiApi, "VertexAiApi must not be null!");
52+
Assert.notNull(vertexAiApi, "VertexAiPaLm2Api must not be null!");
5353

5454
this.vertexAiApi = vertexAiApi;
5555
this.defaultOptions = defaultOptions;
@@ -81,10 +81,10 @@ GenerateMessageRequest createRequest(Prompt prompt) {
8181
.map(m -> m.getContent())
8282
.collect(Collectors.joining("\n"));
8383

84-
List<VertexAiApi.Message> vertexMessages = prompt.getInstructions()
84+
List<VertexAiPaLm2Api.Message> vertexMessages = prompt.getInstructions()
8585
.stream()
8686
.filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
87-
.map(m -> new VertexAiApi.Message(m.getMessageType().getValue(), m.getContent()))
87+
.map(m -> new VertexAiPaLm2Api.Message(m.getMessageType().getValue(), m.getContent()))
8888
.toList();
8989

9090
Assert.isTrue(!CollectionUtils.isEmpty(vertexMessages), "No user or assistant messages found in the prompt!");
@@ -99,8 +99,8 @@ GenerateMessageRequest createRequest(Prompt prompt) {
9999

100100
if (prompt.getOptions() != null) {
101101
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
102-
VertexAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
103-
ChatOptions.class, VertexAiChatOptions.class);
102+
VertexAiPaLm2ChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
103+
ChatOptions.class, VertexAiPaLm2ChatOptions.class);
104104
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, GenerateMessageRequest.class);
105105
}
106106
else {
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
* @author Christian Tzolov
2727
*/
2828
@JsonInclude(Include.NON_NULL)
29-
public class VertexAiChatOptions implements ChatOptions {
29+
public class VertexAiPaLm2ChatOptions implements ChatOptions {
3030

3131
// @formatter:off
3232
/**
@@ -65,7 +65,7 @@ public static Builder builder() {
6565

6666
public static class Builder {
6767

68-
private VertexAiChatOptions options = new VertexAiChatOptions();
68+
private VertexAiPaLm2ChatOptions options = new VertexAiPaLm2ChatOptions();
6969

7070
public Builder withTemperature(Float temperature) {
7171
this.options.temperature = temperature;
@@ -87,7 +87,7 @@ public Builder withTopK(Integer topK) {
8787
return this;
8888
}
8989

90-
public VertexAiChatOptions build() {
90+
public VertexAiPaLm2ChatOptions build() {
9191
return this.options;
9292
}
9393

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@
2424
import org.springframework.ai.embedding.Embedding;
2525
import org.springframework.ai.embedding.EmbeddingRequest;
2626
import org.springframework.ai.embedding.EmbeddingResponse;
27-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi;
27+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
2828

2929
/**
3030
* @author Christian Tzolov
3131
*/
32-
public class VertexAiEmbeddingClient extends AbstractEmbeddingClient {
32+
public class VertexAiPaLm2EmbeddingClient extends AbstractEmbeddingClient {
3333

34-
private final VertexAiApi vertexAiApi;
34+
private final VertexAiPaLm2Api vertexAiApi;
3535

36-
public VertexAiEmbeddingClient(VertexAiApi vertexAiApi) {
36+
public VertexAiPaLm2EmbeddingClient(VertexAiPaLm2Api vertexAiApi) {
3737
this.vertexAiApi = vertexAiApi;
3838
}
3939

@@ -44,7 +44,7 @@ public List<Double> embed(Document document) {
4444

4545
@Override
4646
public EmbeddingResponse call(EmbeddingRequest request) {
47-
List<VertexAiApi.Embedding> vertexEmbeddings = this.vertexAiApi.batchEmbedText(request.getInstructions());
47+
List<VertexAiPaLm2Api.Embedding> vertexEmbeddings = this.vertexAiApi.batchEmbedText(request.getInstructions());
4848
AtomicInteger indexCounter = new AtomicInteger(0);
4949
List<Embedding> embeddings = vertexEmbeddings.stream()
5050
.map(vm -> new Embedding(vm.value(), indexCounter.getAndIncrement()))

models/spring-ai-vertex-ai-palm2/src/main/java/org/springframework/ai/vertexai/palm2/aot/VertexRuntimeHints.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.springframework.ai.vertexai.palm2.aot;
22

3-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi;
3+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
44
import org.springframework.aot.hint.MemberCategory;
55
import org.springframework.aot.hint.RuntimeHints;
66
import org.springframework.aot.hint.RuntimeHintsRegistrar;
@@ -20,7 +20,7 @@ public class VertexRuntimeHints implements RuntimeHintsRegistrar {
2020
@Override
2121
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
2222
var mcs = MemberCategory.values();
23-
for (var tr : findJsonAnnotatedClassesInPackage(VertexAiApi.class))
23+
for (var tr : findJsonAnnotatedClassesInPackage(VertexAiPaLm2Api.class))
2424
hints.reflection().registerType(tr, mcs);
2525
}
2626

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,11 @@
8282
* topK=null
8383
* </pre>
8484
*
85+
* https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
86+
*
8587
* @author Christian Tzolov
8688
*/
87-
public class VertexAiApi {
89+
public class VertexAiPaLm2Api {
8890

8991
/**
9092
* The default generation model. This model is used to generate responses for the
@@ -115,7 +117,7 @@ public class VertexAiApi {
115117
* Create a new chat completion api.
116118
* @param apiKey vertex apiKey.
117119
*/
118-
public VertexAiApi(String apiKey) {
120+
public VertexAiPaLm2Api(String apiKey) {
119121
this(DEFAULT_BASE_URL, apiKey, DEFAULT_GENERATE_MODEL, DEFAULT_EMBEDDING_MODEL, RestClient.builder());
120122
}
121123

@@ -127,7 +129,7 @@ public VertexAiApi(String apiKey) {
127129
* @param embeddingModel vertex embedding model.
128130
* @param restClientBuilder RestClient builder.
129131
*/
130-
public VertexAiApi(String baseUrl, String apiKey, String model, String embeddingModel,
132+
public VertexAiPaLm2Api(String baseUrl, String apiKey, String model, String embeddingModel,
131133
RestClient.Builder restClientBuilder) {
132134

133135
this.chatModel = model;
Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
import org.springframework.ai.parser.BeanOutputParser;
1818
import org.springframework.ai.parser.ListOutputParser;
1919
import org.springframework.ai.parser.MapOutputParser;
20-
import org.springframework.ai.vertexai.palm2.VertexAiChatClient;
21-
import org.springframework.ai.vertexai.palm2.api.VertexAiApi;
20+
import org.springframework.ai.vertexai.palm2.api.VertexAiPaLm2Api;
2221
import org.springframework.beans.factory.annotation.Autowired;
2322
import org.springframework.beans.factory.annotation.Value;
2423
import org.springframework.boot.SpringBootConfiguration;
@@ -31,10 +30,10 @@
3130

3231
@SpringBootTest
3332
@EnabledIfEnvironmentVariable(named = "PALM_API_KEY", matches = ".*")
34-
class VertexAiChatGenerationClientIT {
33+
class VertexAiPaLm2ChatGenerationClientIT {
3534

3635
@Autowired
37-
private VertexAiChatClient client;
36+
private VertexAiPaLm2ChatClient client;
3837

3938
@Value("classpath:/prompts/system-message.st")
4039
private Resource systemResource;
@@ -117,13 +116,13 @@ void beanOutputParserRecords() {
117116
public static class TestConfiguration {
118117

119118
@Bean
120-
public VertexAiApi vertexAiApi() {
121-
return new VertexAiApi(System.getenv("PALM_API_KEY"));
119+
public VertexAiPaLm2Api vertexAiApi() {
120+
return new VertexAiPaLm2Api(System.getenv("PALM_API_KEY"));
122121
}
123122

124123
@Bean
125-
public VertexAiChatClient vertexAiEmbedding(VertexAiApi vertexAiApi) {
126-
return new VertexAiChatClient(vertexAiApi);
124+
public VertexAiPaLm2ChatClient vertexAiEmbedding(VertexAiPaLm2Api vertexAiApi) {
125+
return new VertexAiPaLm2ChatClient(vertexAiApi);
127126
}
128127

129128
}

0 commit comments

Comments
 (0)