Skip to content

Commit 2517ca8

Browse files
rmalara authored and sobychacko committed
GH-2609: Fix the thread leak issue in VertexAiTextEmbeddingModel
Fixes: #2609. The PredictionServiceClient was not being closed, so connections were kept open, preventing resources from being disposed of properly. Signed-off-by: rmalara <rmalara@interactions.com> Signed-off-by: Rodrigo Malara <rodrigomalara@gmail.com>
1 parent 2294c5a commit 2517ca8

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -58,6 +58,7 @@
5858
*
5959
* @author Christian Tzolov
6060
* @author Mark Pollack
61+
* @author Rodrigo Malara
6162
* @since 1.0.0
6263
*/
6364
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {
@@ -128,37 +129,38 @@ public EmbeddingResponse call(EmbeddingRequest request) {
128129
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
129130
this.observationRegistry)
130131
.observe(() -> {
131-
PredictionServiceClient client = createPredictionServiceClient();
132+
try (PredictionServiceClient client = createPredictionServiceClient()) {
132133

133-
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
134+
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());
134135

135-
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
136-
finalOptions);
136+
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
137+
finalOptions);
137138

138-
PredictResponse embeddingResponse = this.retryTemplate
139-
.execute(context -> getPredictResponse(client, predictRequestBuilder));
139+
PredictResponse embeddingResponse = this.retryTemplate
140+
.execute(context -> getPredictResponse(client, predictRequestBuilder));
140141

141-
int index = 0;
142-
int totalTokenCount = 0;
143-
List<Embedding> embeddingList = new ArrayList<>();
144-
for (Value prediction : embeddingResponse.getPredictionsList()) {
145-
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
146-
Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
147-
Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
148-
totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
142+
int index = 0;
143+
int totalTokenCount = 0;
144+
List<Embedding> embeddingList = new ArrayList<>();
145+
for (Value prediction : embeddingResponse.getPredictionsList()) {
146+
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
147+
Value statistics = embeddings.getStructValue().getFieldsOrThrow("statistics");
148+
Value tokenCount = statistics.getStructValue().getFieldsOrThrow("token_count");
149+
totalTokenCount = totalTokenCount + (int) tokenCount.getNumberValue();
149150

150-
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
151+
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
151152

152-
float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
153+
float[] vectorValues = VertexAiEmbeddingUtils.toVector(values);
153154

154-
embeddingList.add(new Embedding(vectorValues, index++));
155-
}
156-
EmbeddingResponse response = new EmbeddingResponse(embeddingList,
157-
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
155+
embeddingList.add(new Embedding(vectorValues, index++));
156+
}
157+
EmbeddingResponse response = new EmbeddingResponse(embeddingList,
158+
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
158159

159-
observationContext.setResponse(response);
160+
observationContext.setResponse(response);
160161

161-
return response;
162+
return response;
163+
}
162164
});
163165
}
164166

0 commit comments

Comments
 (0)