|
1 | 1 | /*
|
2 |
| - * Copyright 2023-2024 the original author or authors. |
| 2 | + * Copyright 2023-2025 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
58 | 58 | *
|
59 | 59 | * @author Christian Tzolov
|
60 | 60 | * @author Mark Pollack
|
| 61 | + * @author Rodrigo Malara |
61 | 62 | * @since 1.0.0
|
62 | 63 | */
|
63 | 64 | public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
|
@@ -128,37 +129,38 @@ public EmbeddingResponse call(EmbeddingRequest request) {
|
128 | 129 | .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
|
129 | 130 | this.observationRegistry)
|
130 | 131 | .observe(() -> {
|
131 |
| - PredictionServiceClient client = createPredictionServiceClient(); |
| 132 | + try (PredictionServiceClient client = createPredictionServiceClient()) { |
132 | 133 |
|
133 |
| - EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); |
| 134 | + EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel()); |
134 | 135 |
|
135 |
| - PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, |
136 |
| - finalOptions); |
| 136 | + PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, |
| 137 | + finalOptions); |
137 | 138 |
|
138 |
| - PredictResponse embeddingResponse = this.retryTemplate |
139 |
| - .execute(context -> getPredictResponse(client, predictRequestBuilder)); |
| 139 | + PredictResponse embeddingResponse = this.retryTemplate |
| 140 | + .execute(context -> getPredictResponse(client, predictRequestBuilder)); |
140 | 141 |
|
141 |
| - int index = 0; |
142 |
| - int totalTokenCount = 0; |
143 |
| - List<Embedding> embeddingList = new ArrayList<>(); |
144 |
| - for (Value prediction : embeddingResponse.getPredictionsList()) { |
145 |
| - Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); |
146 |
| - Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); |
147 |
| - Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); |
148 |
| - totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); |
| 142 | + int index = 0; |
| 143 | + int totalTokenCount = 0; |
| 144 | + List<Embedding> embeddingList = new ArrayList<>(); |
| 145 | + for (Value prediction : embeddingResponse.getPredictionsList()) { |
| 146 | + Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); |
| 147 | + Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics"); |
| 148 | + Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count"); |
| 149 | + totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue(); |
149 | 150 |
|
150 |
| - Value values = embeddings.getStructValue().getFieldsOrThrow("values"); |
| 151 | + Value values = embeddings.getStructValue().getFieldsOrThrow("values"); |
151 | 152 |
|
152 |
| - float[] vectorValues = VertexAiEmbeddingUtils.toVector(values); |
| 153 | + float[] vectorValues = VertexAiEmbeddingUtils.toVector(values); |
153 | 154 |
|
154 |
| - embeddingList.add(new Embedding(vectorValues, index++)); |
155 |
| - } |
156 |
| - EmbeddingResponse response = new EmbeddingResponse(embeddingList, |
157 |
| - generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); |
| 155 | + embeddingList.add(new Embedding(vectorValues, index++)); |
| 156 | + } |
| 157 | + EmbeddingResponse response = new EmbeddingResponse(embeddingList, |
| 158 | + generateResponseMetadata(finalOptions.getModel(), totalTokenCount)); |
158 | 159 |
|
159 |
| - observationContext.setResponse(response); |
| 160 | + observationContext.setResponse(response); |
160 | 161 |
|
161 |
| - return response; |
| 162 | + return response; |
| 163 | + } |
162 | 164 | });
|
163 | 165 | }
|
164 | 166 |
|
|
0 commit comments