Skip to content

Commit ccf190c

Browse files
anders-swanson authored and markpollack committed
Add OCI GenAI embedding model support
This commit introduces support for Oracle Cloud Infrastructure (OCI) GenAI embedding models in Spring AI. It includes:

* New OCIEmbeddingModel class for interacting with the OCI GenAI API
* Auto-configuration for easy setup and integration
* Properties for configuring OCI connection and embedding options
* Documentation updates explaining usage and configuration
* Integration tests to verify functionality

Signed-off-by: Anders Swanson <anders.swanson@oracle.com>
1 parent 5da44c4 commit ccf190c

File tree

20 files changed

+1158
-2
lines changed

20 files changed

+1158
-2
lines changed

models/spring-ai-oci-genai/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[Oracle Cloud Infrastructure GenAI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm)

models/spring-ai-oci-genai/pom.xml

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
4+
<!-- Maven module descriptor for the Spring AI OCI GenAI model integration. -->
<modelVersion>4.0.0</modelVersion>
5+
<!-- Parent is the spring-ai POM located two directories up (see relativePath). -->
<parent>
6+
<groupId>org.springframework.ai</groupId>
7+
<artifactId>spring-ai</artifactId>
8+
<version>1.0.0-SNAPSHOT</version>
9+
<relativePath>../../pom.xml</relativePath>
10+
</parent>
11+
<!-- Coordinates and descriptive metadata of this module; groupId and version are not declared here and are inherited from the parent. -->
<artifactId>spring-ai-oci-genai</artifactId>
12+
<packaging>jar</packaging>
13+
<name>Spring AI Model - OCI GenAI</name>
14+
<description>OCI GenAI models support</description>
15+
<url>https://github.com/spring-projects/spring-ai</url>
16+
17+
<scm>
18+
<url>https://github.com/spring-projects/spring-ai</url>
19+
<connection>git://github.com/spring-projects/spring-ai.git</connection>
20+
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
21+
</scm>
22+
23+
<dependencies>
24+
25+
<!-- production dependencies -->
26+
<!-- Spring AI core module, pinned to the parent's version. -->
<dependency>
27+
<groupId>org.springframework.ai</groupId>
28+
<artifactId>spring-ai-core</artifactId>
29+
<version>${project.parent.version}</version>
30+
</dependency>
31+
32+
<!-- OCI Java SDK (shaded full artifact). The oci-sdk-version property is not defined in this file; presumably it is declared in the parent POM (confirm). -->
<dependency>
33+
<groupId>com.oracle.oci.sdk</groupId>
34+
<artifactId>oci-java-sdk-shaded-full</artifactId>
35+
<version>${oci-sdk-version}</version>
36+
</dependency>
37+
38+
<!-- OCI SDK add-on for OKE workload identity; uses the same oci-sdk-version property as the SDK dependency above. -->
<dependency>
39+
<groupId>com.oracle.oci.sdk</groupId>
40+
<artifactId>oci-java-sdk-addons-oke-workload-identity</artifactId>
41+
<version>${oci-sdk-version}</version>
42+
</dependency>
43+
44+
<!-- NOTE: Required only by the @ConstructorBinding. -->
45+
<dependency>
46+
<groupId>org.springframework.boot</groupId>
47+
<artifactId>spring-boot</artifactId>
48+
</dependency>
49+
50+
<!-- The Spring dependencies in this group declare no version here; presumably their versions are managed by the parent POM (confirm). -->
<dependency>
51+
<groupId>org.springframework</groupId>
52+
<artifactId>spring-context-support</artifactId>
53+
</dependency>
54+
<dependency>
55+
<groupId>org.springframework.boot</groupId>
56+
<artifactId>spring-boot-starter-logging</artifactId>
57+
</dependency>
58+
59+
<!-- test dependencies -->
60+
<!-- Spring AI test support, test scope only, at this module's own version. -->
<dependency>
61+
<groupId>org.springframework.ai</groupId>
62+
<artifactId>spring-ai-test</artifactId>
63+
<version>${project.version}</version>
64+
<scope>test</scope>
65+
</dependency>
66+
67+
</dependencies>
68+
69+
</project>
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.oci;
17+
18+
import java.util.ArrayList;
19+
import java.util.List;
20+
import java.util.Objects;
21+
import java.util.concurrent.atomic.AtomicInteger;
22+
23+
import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
24+
import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode;
25+
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
26+
import com.oracle.bmc.generativeaiinference.model.EmbedTextResult;
27+
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
28+
import com.oracle.bmc.generativeaiinference.model.ServingMode;
29+
import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest;
30+
import io.micrometer.observation.ObservationRegistry;
31+
import org.springframework.ai.chat.metadata.EmptyUsage;
32+
import org.springframework.ai.document.Document;
33+
import org.springframework.ai.embedding.AbstractEmbeddingModel;
34+
import org.springframework.ai.embedding.Embedding;
35+
import org.springframework.ai.embedding.EmbeddingOptions;
36+
import org.springframework.ai.embedding.EmbeddingRequest;
37+
import org.springframework.ai.embedding.EmbeddingResponse;
38+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
39+
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
40+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
41+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
42+
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
43+
import org.springframework.ai.model.ModelOptionsUtils;
44+
import org.springframework.ai.observation.conventions.AiProvider;
45+
import org.springframework.util.Assert;
46+
47+
/**
48+
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
49+
* OCI GenAI Embedding API.
50+
*
51+
* @author Anders Swanson
52+
* @since 1.0.0
53+
*/
54+
public class OCIEmbeddingModel extends AbstractEmbeddingModel {
55+
56+
// The OCI GenAI API has a batch size of 96 for embed text requests.
57+
private static final int EMBEDTEXT_BATCH_SIZE = 96;
58+
59+
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();
60+
61+
private final GenerativeAiInference genAi;
62+
63+
private final OCIEmbeddingOptions options;
64+
65+
private final ObservationRegistry observationRegistry;
66+
67+
private final EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
68+
69+
public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options) {
70+
this(genAi, options, ObservationRegistry.NOOP);
71+
}
72+
73+
public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options,
74+
ObservationRegistry observationRegistry) {
75+
Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient must not be null");
76+
Assert.notNull(options, "options must not be null");
77+
Assert.notNull(observationRegistry, "observationRegistry must not be null");
78+
this.genAi = genAi;
79+
this.options = options;
80+
this.observationRegistry = observationRegistry;
81+
}
82+
83+
@Override
84+
public EmbeddingResponse call(EmbeddingRequest request) {
85+
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
86+
OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), options);
87+
List<EmbedTextRequest> embedTextRequests = createRequests(request.getInstructions(), runtimeOptions);
88+
89+
EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder()
90+
.embeddingRequest(request)
91+
.provider(AiProvider.OCI_GENAI.value())
92+
.requestOptions(runtimeOptions)
93+
.build();
94+
95+
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
96+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> context,
97+
this.observationRegistry)
98+
.observe(() -> embedAllWithContext(embedTextRequests, context));
99+
}
100+
101+
@Override
102+
public float[] embed(Document document) {
103+
return embed(document.getContent());
104+
}
105+
106+
private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRequests,
107+
EmbeddingModelObservationContext context) {
108+
String modelId = null;
109+
AtomicInteger index = new AtomicInteger(0);
110+
List<Embedding> embeddings = new ArrayList<>();
111+
for (EmbedTextRequest embedTextRequest : embedTextRequests) {
112+
EmbedTextResult embedTextResult = genAi.embedText(embedTextRequest).getEmbedTextResult();
113+
if (modelId == null) {
114+
modelId = embedTextResult.getModelId();
115+
}
116+
for (List<Float> e : embedTextResult.getEmbeddings()) {
117+
float[] data = toFloats(e);
118+
embeddings.add(new Embedding(data, index.getAndIncrement()));
119+
}
120+
}
121+
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
122+
metadata.setModel(modelId);
123+
metadata.setUsage(new EmptyUsage());
124+
EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata);
125+
context.setResponse(embeddingResponse);
126+
return embeddingResponse;
127+
}
128+
129+
private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) {
130+
return switch (embeddingOptions.getServingMode()) {
131+
case "dedicated" -> DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build();
132+
case "on-demand" -> OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build();
133+
default -> throw new IllegalArgumentException(
134+
"unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode());
135+
};
136+
}
137+
138+
private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
139+
int size = inputs.size();
140+
List<EmbedTextRequest> requests = new ArrayList<>();
141+
for (int i = 0; i < inputs.size(); i += EMBEDTEXT_BATCH_SIZE) {
142+
List<String> batch = inputs.subList(i, Math.min(i + EMBEDTEXT_BATCH_SIZE, size));
143+
requests.add(createRequest(batch, embeddingOptions));
144+
}
145+
return requests;
146+
}
147+
148+
private EmbedTextRequest createRequest(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
149+
EmbedTextDetails embedTextDetails = EmbedTextDetails.builder()
150+
.servingMode(servingMode(embeddingOptions))
151+
.compartmentId(embeddingOptions.getCompartment())
152+
.inputs(inputs)
153+
.truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End))
154+
.build();
155+
return EmbedTextRequest.builder().embedTextDetails(embedTextDetails).build();
156+
}
157+
158+
private OCIEmbeddingOptions mergeOptions(EmbeddingOptions embeddingOptions, OCIEmbeddingOptions defaultOptions) {
159+
if (embeddingOptions instanceof OCIEmbeddingOptions) {
160+
OCIEmbeddingOptions dynamicOptions = ModelOptionsUtils.merge(embeddingOptions, defaultOptions,
161+
OCIEmbeddingOptions.class);
162+
if (dynamicOptions != null) {
163+
return dynamicOptions;
164+
}
165+
}
166+
return defaultOptions;
167+
}
168+
169+
private float[] toFloats(List<Float> embedding) {
170+
float[] floats = new float[embedding.size()];
171+
for (int i = 0; i < embedding.size(); i++) {
172+
floats[i] = embedding.get(i);
173+
}
174+
return floats;
175+
}
176+
177+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.oci;
17+
18+
import com.fasterxml.jackson.annotation.JsonInclude;
19+
import com.fasterxml.jackson.annotation.JsonProperty;
20+
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
21+
import org.springframework.ai.embedding.EmbeddingOptions;
22+
23+
/**
24+
* The configuration information for OCI embedding requests
25+
*
26+
* @author Anders Swanson
27+
*/
28+
@JsonInclude(JsonInclude.Include.NON_NULL)
29+
public class OCIEmbeddingOptions implements EmbeddingOptions {
30+
31+
private @JsonProperty("model") String model;
32+
33+
private @JsonProperty("compartment") String compartment;
34+
35+
private @JsonProperty("servingMode") String servingMode;
36+
37+
private @JsonProperty("truncate") EmbedTextDetails.Truncate truncate;
38+
39+
public static Builder builder() {
40+
return new Builder();
41+
}
42+
43+
public static class Builder {
44+
45+
private final OCIEmbeddingOptions options = new OCIEmbeddingOptions();
46+
47+
public Builder withModel(String model) {
48+
this.options.setModel(model);
49+
return this;
50+
}
51+
52+
public Builder withCompartment(String compartment) {
53+
this.options.setCompartment(compartment);
54+
return this;
55+
}
56+
57+
public Builder withServingMode(String servingMode) {
58+
this.options.setServingMode(servingMode);
59+
return this;
60+
}
61+
62+
public Builder withTruncate(EmbedTextDetails.Truncate truncate) {
63+
this.options.truncate = truncate;
64+
return this;
65+
}
66+
67+
public OCIEmbeddingOptions build() {
68+
return this.options;
69+
}
70+
71+
}
72+
73+
public String getModel() {
74+
return this.model;
75+
}
76+
77+
/**
78+
* Not used by OCI GenAI.
79+
* @return null
80+
*/
81+
@Override
82+
public Integer getDimensions() {
83+
return null;
84+
}
85+
86+
public void setModel(String model) {
87+
this.model = model;
88+
}
89+
90+
public String getCompartment() {
91+
return compartment;
92+
}
93+
94+
public void setCompartment(String compartment) {
95+
this.compartment = compartment;
96+
}
97+
98+
public String getServingMode() {
99+
return servingMode;
100+
}
101+
102+
public void setServingMode(String servingMode) {
103+
this.servingMode = servingMode;
104+
}
105+
106+
public EmbedTextDetails.Truncate getTruncate() {
107+
return truncate;
108+
}
109+
110+
public void setTruncate(EmbedTextDetails.Truncate truncate) {
111+
this.truncate = truncate;
112+
}
113+
114+
}

0 commit comments

Comments
 (0)