|
| 1 | +/* |
| 2 | + ** Copyright (c) 2024, Oracle and/or its affiliates. |
| 3 | + ** Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 4 | + */ |
| 5 | + |
| 6 | +package com.oracle.cloud.spring.genai; |
| 7 | + |
| 8 | +import com.oracle.bmc.auth.RegionProvider; |
| 9 | +import com.oracle.bmc.generativeaiinference.GenerativeAiInference; |
| 10 | +import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient; |
| 11 | +import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode; |
| 12 | +import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; |
| 13 | +import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; |
| 14 | +import com.oracle.bmc.generativeaiinference.model.ServingMode; |
| 15 | +import com.oracle.cloud.spring.autoconfigure.core.CredentialsProvider; |
| 16 | +import org.springframework.beans.factory.annotation.Qualifier; |
| 17 | +import org.springframework.boot.autoconfigure.AutoConfiguration; |
| 18 | +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; |
| 19 | +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; |
| 20 | +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; |
| 21 | +import org.springframework.boot.context.properties.EnableConfigurationProperties; |
| 22 | +import org.springframework.cloud.context.config.annotation.RefreshScope; |
| 23 | +import org.springframework.context.annotation.Bean; |
| 24 | +import org.springframework.util.StringUtils; |
| 25 | + |
| 26 | +import static com.oracle.cloud.spring.autoconfigure.core.CredentialsProviderAutoConfiguration.credentialsProviderQualifier; |
| 27 | +import static com.oracle.cloud.spring.autoconfigure.core.RegionProviderAutoConfiguration.regionProviderQualifier; |
| 28 | + |
| 29 | +/** |
| 30 | + * Auto-configuration for initializing the OCI GenAI component. |
| 31 | + * Depends on {@link com.oracle.cloud.spring.autoconfigure.core.CredentialsProviderAutoConfiguration} and |
| 32 | + * {@link com.oracle.cloud.spring.autoconfigure.core.RegionProviderAutoConfiguration} |
| 33 | + * for loading the Authentication configuration |
| 34 | + * |
| 35 | + * @see ChatModel |
| 36 | + * @see EmbeddingModel |
| 37 | + */ |
| 38 | +@AutoConfiguration |
| 39 | +@ConditionalOnClass({ChatModel.class}) |
| 40 | +@EnableConfigurationProperties(GenAIProperties.class) |
| 41 | +@ConditionalOnProperty(name = "spring.cloud.oci.genai.enabled", havingValue = "true", matchIfMissing = true) |
| 42 | +public class GenAIAutoConfiguration { |
| 43 | + private final GenAIProperties properties; |
| 44 | + |
| 45 | + public GenAIAutoConfiguration(GenAIProperties properties) { |
| 46 | + this.properties = properties; |
| 47 | + } |
| 48 | + |
| 49 | + @Bean |
| 50 | + @RefreshScope |
| 51 | + @ConditionalOnProperty(name = "spring.cloud.oci.genai.embedding.enabled", havingValue = "true", matchIfMissing = true) |
| 52 | + @ConditionalOnMissingBean(EmbeddingModel.class) |
| 53 | + public EmbeddingModel embeddingModel(GenerativeAiInference generativeAiInference) { |
| 54 | + GenAIProperties.Embedding embedding = properties.getEmbedding(); |
| 55 | + return EmbeddingModelImpl.builder() |
| 56 | + .client(generativeAiInference) |
| 57 | + .truncate(StringUtils.hasText(embedding.getTruncate()) ? |
| 58 | + EmbedTextDetails.Truncate.valueOf(embedding.getTruncate()) : |
| 59 | + EmbedTextDetails.Truncate.None) |
| 60 | + .compartment(embedding.getCompartment()) |
| 61 | + .servingMode(servingMode(embedding.getOnDemandModelId(), embedding.getDedicatedClusterEndpoint())) |
| 62 | + .build(); |
| 63 | + } |
| 64 | + |
| 65 | + @Bean |
| 66 | + @RefreshScope |
| 67 | + @ConditionalOnProperty(name = "spring.cloud.oci.genai.chat.enabled", havingValue = "true", matchIfMissing = true) |
| 68 | + @ConditionalOnMissingBean(ChatModel.class) |
| 69 | + public ChatModel chatModel(GenerativeAiInference generativeAiInference) { |
| 70 | + GenAIProperties.Chat chat = properties.getChat(); |
| 71 | + return ChatModelImpl.builder() |
| 72 | + .client(generativeAiInference) |
| 73 | + .preambleOverride(chat.getPreambleOverride()) |
| 74 | + .inferenceRequestType(chat.getInferenceRequestType()) |
| 75 | + .servingMode(servingMode(chat.getOnDemandModelId(), chat.getDedicatedClusterEndpoint())) |
| 76 | + .topK(chat.getTopK()) |
| 77 | + .topP(chat.getTopP()) |
| 78 | + .compartment(chat.getCompartment()) |
| 79 | + .frequencyPenalty(chat.getFrequencyPenalty()) |
| 80 | + .presencePenalty(chat.getPresencePenalty()) |
| 81 | + .temperature(chat.getTemperature()) |
| 82 | + .build(); |
| 83 | + } |
| 84 | + |
| 85 | + @Bean |
| 86 | + @RefreshScope |
| 87 | + @ConditionalOnMissingBean |
| 88 | + GenerativeAiInference genAIClient(@Qualifier(regionProviderQualifier) RegionProvider regionProvider, |
| 89 | + @Qualifier(credentialsProviderQualifier) |
| 90 | + CredentialsProvider cp) { |
| 91 | + GenerativeAiInference generativeAiInference = GenerativeAiInferenceClient.builder() |
| 92 | + .build(cp.getAuthenticationDetailsProvider()); |
| 93 | + if (regionProvider.getRegion() != null) { |
| 94 | + generativeAiInference.setRegion(regionProvider.getRegion()); |
| 95 | + } |
| 96 | + return generativeAiInference; |
| 97 | + } |
| 98 | + |
| 99 | + private ServingMode servingMode(String onDemandModelId, String dedicatedClusterEndpoint) { |
| 100 | + if (StringUtils.hasText(onDemandModelId)) { |
| 101 | + return OnDemandServingMode.builder().modelId(onDemandModelId).build(); |
| 102 | + } else if (StringUtils.hasText(dedicatedClusterEndpoint)) { |
| 103 | + return DedicatedServingMode.builder().endpointId(dedicatedClusterEndpoint).build(); |
| 104 | + } |
| 105 | + throw new IllegalArgumentException("One of spring.cloud.oci.genai.embedding.onDemandModelId or spring.cloud.oci.genai.embedding.dedicatedClusterEndpoint must be specified."); |
| 106 | + } |
| 107 | +} |
0 commit comments