diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 7cd7bdb23a9..de09d0edf07 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -206,7 +206,6 @@ public List getFunctionCallbacks() { return this.functionCallbacks; } - @Override public void setFunctionCallbacks(List functionCallbacks) { Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); this.functionCallbacks = functionCallbacks; @@ -217,7 +216,6 @@ public Set getFunctions() { return this.functions; } - @Override public void setFunctions(Set functions) { Assert.notNull(functions, "Function must not be null"); this.functions = functions; diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 86c0bda36bb..16167472954 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -23,12 +23,11 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.function.PortableFunctionCallingOptions; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -38,7 +37,7 @@ * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions { +public class MistralAiChatOptions implements PortableFunctionCallingOptions { /** * ID of the model to use @@ -292,21 +291,9 @@ public List getFunctionCallbacks() { return this.functionCallbacks; } - @Override - public void setFunctionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); - this.functionCallbacks = functionCallbacks; - } - @Override public Set getFunctions() { return this.functions; } - @Override - public void setFunctions(Set functions) { - Assert.notNull(functions, "Function must not be null"); - this.functions = functions; - } - } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java index b99c7cca6de..46a003dde6c 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java @@ -77,7 +77,7 @@ public class OpenAiAudioTranscriptionClient */ public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi) { this(audioApi, - OpenAiAudioTranscriptionOptions.builder() + OpenAiAudioTranscriptionOptionsBuilder.builder() .withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue()) .withResponseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON) .withTemperature(0.7f) @@ -209,19 +209,18 @@ private byte[] toBytes(Resource resource) { private OpenAiAudioTranscriptionOptions merge(OpenAiAudioTranscriptionOptions source, OpenAiAudioTranscriptionOptions target) { - + OpenAiAudioTranscriptionOptionsBuilder builder = OpenAiAudioTranscriptionOptionsBuilder.builder(); if (source == null) { - source = new OpenAiAudioTranscriptionOptions(); + source = OpenAiAudioTranscriptionOptionsBuilder.builder().build(); } - OpenAiAudioTranscriptionOptions merged = new OpenAiAudioTranscriptionOptions(); - merged.setLanguage(source.getLanguage() != null ? source.getLanguage() : target.getLanguage()); - merged.setModel(source.getModel() != null ? source.getModel() : target.getModel()); - merged.setPrompt(source.getPrompt() != null ? source.getPrompt() : target.getPrompt()); - merged.setResponseFormat( + builder.withLanguage(source.getLanguage() != null ? source.getLanguage() : target.getLanguage()); + builder.withModel(source.getModel() != null ? source.getModel() : target.getModel()); + builder.withPrompt(source.getPrompt() != null ? source.getPrompt() : target.getPrompt()); + builder.withResponseFormat( source.getResponseFormat() != null ? source.getResponseFormat() : target.getResponseFormat()); - merged.setTemperature(source.getTemperature() != null ? source.getTemperature() : target.getTemperature()); - return merged; + builder.withTemperature(source.getTemperature() != null ? source.getTemperature() : target.getTemperature()); + return builder.build(); } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java index 0307657e655..f2b36d0aefd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptions.java @@ -18,187 +18,47 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; - import org.springframework.ai.model.ModelOptions; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType; /** + * @author youngmon * @author Michael Lavelle * @author Christian Tzolov * @since 0.8.1 */ @JsonInclude(Include.NON_NULL) -public class OpenAiAudioTranscriptionOptions implements ModelOptions { +public interface OpenAiAudioTranscriptionOptions extends ModelOptions { - // @formatter:off /** * ID of the model to use. */ - private @JsonProperty("model") String model; + @JsonProperty("model") + String getModel(); /** - * The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt. + * The format of the transcript output, in one of these options: json, text, srt, + * verbose_json, or vtt. */ - private @JsonProperty("response_format") TranscriptResponseFormat responseFormat; + @JsonProperty("response_format") + TranscriptResponseFormat getResponseFormat(); - private @JsonProperty("prompt") String prompt; + @JsonProperty("prompt") + String getPrompt(); - private @JsonProperty("language") String language; + @JsonProperty("language") + String getLanguage(); /** - * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make + * the output more random, while lower values like 0.2 will make it more focused and + * deterministic. */ - private @JsonProperty("temperature") Float temperature; - - private @JsonProperty("timestamp_granularities") GranularityType granularityType; - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - - protected OpenAiAudioTranscriptionOptions options; - - public Builder() { - this.options = new OpenAiAudioTranscriptionOptions(); - } - - public Builder(OpenAiAudioTranscriptionOptions options) { - this.options = options; - } - - public Builder withModel(String model) { - this.options.model = model; - return this; - } - - public Builder withLanguage(String language) { - this.options.language = language; - return this; - } - - public Builder withPrompt(String prompt) { - this.options.prompt = prompt; - return this; - } - - public Builder withResponseFormat(TranscriptResponseFormat responseFormat) { - this.options.responseFormat = responseFormat; - return this; - } - - public Builder withTemperature(Float temperature) { - this.options.temperature = temperature; - return this; - } - - public Builder withGranularityType(GranularityType granularityType) { - this.options.granularityType = granularityType; - return this; - } - - public OpenAiAudioTranscriptionOptions build() { - return this.options; - } - - } - - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - - public String getLanguage() { - return this.language; - } - - public void setLanguage(String language) { - this.language = language; - } - - public String getPrompt() { - return this.prompt; - } - - public void setPrompt(String prompt) { - this.prompt = prompt; - } - - public Float getTemperature() { - return this.temperature; - } - - public void setTemperature(Float temperature) { - this.temperature = temperature; - } - - - public TranscriptResponseFormat getResponseFormat() { - return this.responseFormat; - } - - public void setResponseFormat(TranscriptResponseFormat responseFormat) { - this.responseFormat = responseFormat; - } - - public GranularityType getGranularityType() { - return this.granularityType; - } - - public void setGranularityType(GranularityType granularityType) { - this.granularityType = granularityType; - } + @JsonProperty("temperature") + Float getTemperature(); - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((prompt == null) ? 0 : prompt.hashCode()); - result = prime * result + ((language == null) ? 0 : language.hashCode()); - result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); - return result; - } + @JsonProperty("timestamp_granularities") + GranularityType getGranularityType(); - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - OpenAiAudioTranscriptionOptions other = (OpenAiAudioTranscriptionOptions) obj; - if (this.model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) - return false; - if (this.prompt == null) { - if (other.prompt != null) - return false; - } - else if (!this.prompt.equals(other.prompt)) - return false; - if (this.language == null) { - if (other.language != null) - return false; - } - else if (!this.language.equals(other.language)) - return false; - if (this.responseFormat == null) { - if (other.responseFormat != null) - return false; - } - else if (!this.responseFormat.equals(other.responseFormat)) - return false; - return true; - } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptionsBuilder.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptionsBuilder.java new file mode 100644 index 00000000000..81c0a1e9e56 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionOptionsBuilder.java @@ -0,0 +1,201 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai; + +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType; + +/** + * Builder for {@link OpenAiAudioTranscriptionOptions} + * + * @author youngmon + * @version 0.8.1 + */ +public class OpenAiAudioTranscriptionOptionsBuilder { + + private String model; + + private TranscriptResponseFormat responseFormat; + + private String prompt; + + private String language; + + private Float temperature; + + private GranularityType granularityType; + + private OpenAiAudioTranscriptionOptionsBuilder() { + } + + public static OpenAiAudioTranscriptionOptionsBuilder builder() { + return new OpenAiAudioTranscriptionOptionsBuilder(); + } + + /** + * Copy Constructor for {@link OpenAiAudioTranscriptionOptionsBuilder} + * @param options Existing {@link OpenAiAudioTranscriptionOptions} + * @return new OpenAiAudioTranscriptionsBuilder + */ + public static OpenAiAudioTranscriptionOptionsBuilder builder(OpenAiAudioTranscriptionOptions options) { + return builder().withModel(options.getModel()) + .withResponseFormat(options.getResponseFormat()) + .withPrompt(options.getPrompt()) + .withLanguage(options.getLanguage()) + .withTemperature(options.getTemperature()) + .withGranularityType(options.getGranularityType()); + } + + public OpenAiAudioTranscriptionOptions build() { + return new OpenAiAudioTranscriptionOptionsImpl(this); + } + + public OpenAiAudioTranscriptionOptionsBuilder withModel(final String model) { + if (model == null) + return this; + this.model = model; + return this; + } + + public OpenAiAudioTranscriptionOptionsBuilder withResponseFormat(final TranscriptResponseFormat responseFormat) { + if (responseFormat == null) + return this; + this.responseFormat = responseFormat; + return this; + } + + public OpenAiAudioTranscriptionOptionsBuilder withPrompt(final String prompt) { + if (prompt == null) + return this; + this.prompt = prompt; + return this; + } + + public OpenAiAudioTranscriptionOptionsBuilder withLanguage(final String language) { + if (language == null) + return this; + this.language = language; + return this; + } + + public OpenAiAudioTranscriptionOptionsBuilder withTemperature(final Float temperature) { + if (temperature == null) + return this; + this.temperature = temperature; + return this; + } + + public OpenAiAudioTranscriptionOptionsBuilder withGranularityType(final GranularityType granularityType) { + if (granularityType == null) + return this; + this.granularityType = granularityType; + return this; + } + + private static class OpenAiAudioTranscriptionOptionsImpl implements OpenAiAudioTranscriptionOptions { + + private final String model; + + private final TranscriptResponseFormat responseFormat; + + private final String prompt; + + private final String language; + + private final Float temperature; + + private final GranularityType granularityType; + + private OpenAiAudioTranscriptionOptionsImpl(OpenAiAudioTranscriptionOptionsBuilder builder) { + this.model = builder.model; + this.responseFormat = builder.responseFormat; + this.prompt = builder.prompt; + this.language = builder.language; + this.temperature = builder.temperature; + this.granularityType = builder.granularityType; + } + + @Override + public String getModel() { + return this.model; + } + + @Override + public TranscriptResponseFormat getResponseFormat() { + return this.responseFormat; + } + + @Override + public String getPrompt() { + return this.prompt; + } + + @Override + public String getLanguage() { + return this.language; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + @Override + public GranularityType getGranularityType() { + return this.granularityType; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((prompt == null) ? 0 : prompt.hashCode()); + result = prime * result + ((language == null) ? 0 : language.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + OpenAiAudioTranscriptionOptionsImpl other = (OpenAiAudioTranscriptionOptionsImpl) obj; + if ((this.model == null) != (other.model == null)) + return false; + else if (!model.equals(other.model)) + return false; + if ((this.prompt == null) != (other.prompt == null)) + return false; + else if (!this.prompt.equals(other.prompt)) + return false; + if ((this.language == null) != (other.language == null)) + return false; + else if (!this.language.equals(other.language)) + return false; + if ((this.responseFormat == null) != (other.responseFormat == null)) + return false; + else + return this.responseFormat.equals(other.responseFormat); + } + + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index 8e86ecdd19a..0f09e6f2e90 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -101,7 +101,10 @@ public class OpenAiChatClient extends */ public OpenAiChatClient(OpenAiApi openAiApi) { this(openAiApi, - OpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()); + OpenAiChatOptionsBuilder.builder() + .withModel(OpenAiApi.DEFAULT_CHAT_MODEL) + .withTemperature(0.7f) + .build()); } /** @@ -262,20 +265,20 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream); if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof ChatOptions runtimeOptions) { - OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, - ChatOptions.class, OpenAiChatOptions.class); - - Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions, - IS_RUNTIME_CALL); - functionsForThisRequest.addAll(promptEnabledFunctions); - - request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class); - } - else { + OpenAiChatOptionsBuilder builder; + if (prompt.getOptions() instanceof OpenAiChatOptions runtimeOptions) + builder = OpenAiChatOptionsBuilder.builder(runtimeOptions); + else if (prompt.getOptions() instanceof ChatOptions runtimeOptions) + builder = OpenAiChatOptionsBuilder.builder().withChatOptions(runtimeOptions); + else throw new IllegalArgumentException("Prompt options are not of type ChatOptions: " + prompt.getOptions().getClass().getSimpleName()); - } + + Set promptEnabledFunctions = this.handleFunctionCallbackConfigurations(builder.build(), + IS_RUNTIME_CALL); + functionsForThisRequest.addAll(promptEnabledFunctions); + builder.withFunctions(promptEnabledFunctions); + request = ModelOptionsUtils.merge(builder.build(), request, ChatCompletionRequest.class); } if (this.defaultOptions != null) { @@ -291,9 +294,9 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { // Add the enabled functions definitions to the request's tools parameter. if (!CollectionUtils.isEmpty(functionsForThisRequest)) { - request = ModelOptionsUtils.merge( - OpenAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(), - request, ChatCompletionRequest.class); + request = ModelOptionsUtils.merge(OpenAiChatOptionsBuilder.builder() + .withTools(this.getFunctionTools(functionsForThisRequest)) + .build(), request, ChatCompletionRequest.class); } return request; diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index d45d4db18cd..2673bdaf4a1 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -411,7 +411,6 @@ public List getFunctionCallbacks() { return this.functionCallbacks; } - @Override public void setFunctionCallbacks(List functionCallbacks) { this.functionCallbacks = functionCallbacks; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptionsBuilder.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptionsBuilder.java new file mode 100644 index 00000000000..502fbe9385e --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptionsBuilder.java @@ -0,0 +1,488 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; + +public class OpenAiChatOptionsBuilder { + + private final FunctionCallingOptionsBuilder functionCallingOptionsBuilder = FunctionCallingOptionsBuilder.builder(); + + private final ChatOptionsBuilder chatOptionsBuilder = ChatOptionsBuilder.builder(); + + private String model; + + private Float frequencyPenalty; + + private ResponseFormat responseFormat; + + private Integer maxTokens; + + private Integer n; + + private Float presencePenalty; + + private Integer seed; + + private String toolChoice; + + private String user; + + private final List stop = new ArrayList<>(); + + private final List tools = new ArrayList<>(); + + private final Map logitBias = new HashMap<>(); + + private OpenAiChatOptionsBuilder() { + } + + public static OpenAiChatOptionsBuilder builder() { + return new OpenAiChatOptionsBuilder(); + } + + /** + * Copy Constructor for {@link OpenAiChatOptionsBuilder} + * @param options Existing {@link OpenAiChatOptions} + * @return new OpenAiChatOptionsBuilder + */ + public static OpenAiChatOptionsBuilder builder(final OpenAiChatOptions options) { + return builder().withFunctionCallingOptions(options) + .withChatOptions(options) + .withFrequencyPenalty(options.getFrequencyPenalty()) + .withPresencePenalty(options.getPresencePenalty()) + .withMaxTokens(options.getMaxTokens()) + .withLogitBias(options.getLogitBias()) + .withUser(options.getUser()) + .withModel(options.getModel()) + .withToolChoice(options.getToolChoice()) + .withTools(options.getTools()) + .withN(options.getN()) + .withSeed(options.getSeed()) + .withStop(options.getStop()) + .withResponseFormat(options.getResponseFormat()); + } + + public OpenAiChatOptions build() { + return new OpenAiChatOptionsImpl(this); + } + + public OpenAiChatOptionsBuilder withFunctionCallingOptions(final FunctionCallingOptions options) { + if (options == null) + return this; + withFunctionCallbacks(options.getFunctionCallbacks()); + withFunctions(options.getFunctions()); + return this; + } + + public OpenAiChatOptionsBuilder withChatOptions(final ChatOptions options) { + if (options == null) + return this; + withTopP(options.getTopP()); + withTemperature(options.getTemperature()); + return this; + } + + public OpenAiChatOptionsBuilder withModel(final String model) { + if (model == null) + return this; + this.model = model; + return this; + } + + public OpenAiChatOptionsBuilder withFrequencyPenalty(final Float frequencyPenalty) { + if (frequencyPenalty == null) + return this; + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public OpenAiChatOptionsBuilder withLogitBias(final Map logitBias) { + if (logitBias == null) + return this; + this.logitBias.putAll(logitBias); + return this; + } + + public OpenAiChatOptionsBuilder withMaxTokens(final Integer maxTokens) { + if (maxTokens == null) + return this; + this.maxTokens = maxTokens; + return this; + } + + public OpenAiChatOptionsBuilder withN(final Integer n) { + if (n == null) + return this; + this.n = n; + return this; + } + + public OpenAiChatOptionsBuilder withPresencePenalty(final Float presencePenalty) { + if (presencePenalty == null) + return this; + this.presencePenalty = presencePenalty; + return this; + } + + public OpenAiChatOptionsBuilder withResponseFormat(final ResponseFormat responseFormat) { + if (responseFormat == null) + return this; + this.responseFormat = responseFormat; + return this; + } + + public OpenAiChatOptionsBuilder withSeed(final Integer seed) { + if (seed == null) + return this; + this.seed = seed; + return this; + } + + public OpenAiChatOptionsBuilder withStop(final List stop) { + if (stop == null) + return this; + this.stop.addAll(stop); + return this; + } + + public OpenAiChatOptionsBuilder withTemperature(final Float temperature) { + if (temperature == null) + return this; + this.chatOptionsBuilder.withTemperature(temperature); + return this; + } + + public OpenAiChatOptionsBuilder withTopP(final Float topP) { + if (topP == null) + return this; + this.chatOptionsBuilder.withTopP(topP); + return this; + } + + public OpenAiChatOptionsBuilder withTools(final List tools) { + if (tools == null) + return this; + this.tools.addAll(tools); + return this; + } + + public OpenAiChatOptionsBuilder withToolChoice(final String toolChoice) { + if (toolChoice == null) + return this; + this.toolChoice = toolChoice; + return this; + } + + public OpenAiChatOptionsBuilder withUser(final String user) { + if (user == null) + return this; + this.user = user; + return this; + } + + public OpenAiChatOptionsBuilder withFunctionCallbacks(final List functionCallbacks) { + if (functionCallbacks == null) + return this; + this.functionCallingOptionsBuilder.withFunctionCallbacks(functionCallbacks); + return this; + } + + public OpenAiChatOptionsBuilder withFunctionCallback(final FunctionCallback functionCallback) { + if (functionCallback == null) + return this; + this.functionCallingOptionsBuilder.withFunctionCallback(functionCallback); + return this; + } + + public OpenAiChatOptionsBuilder withFunctions(final Set functionNames) { + if (functionNames == null) + return this; + this.functionCallingOptionsBuilder.withFunctions(functionNames); + return this; + } + + public OpenAiChatOptionsBuilder withFunction(final String functionName) { + if (functionName != null) + this.functionCallingOptionsBuilder.withFunction(functionName); + return this; + } + + private static class OpenAiChatOptionsImpl extends OpenAiChatOptions { + + private final ChatOptions chatOptions; + + private final FunctionCallingOptions functionCallingOptions; + + private final String model; + + private final Float frequencyPenalty; + + private final ResponseFormat responseFormat; + + private final Map logitBias; + + private final Integer maxTokens; + + private final Integer n; + + private final Float presencePenalty; + + private final Integer seed; + + private final String toolChoice; + + private final String user; + + private final List stop; + + private final List tools; + + private OpenAiChatOptionsImpl(final OpenAiChatOptionsBuilder builder) { + this.chatOptions = builder.chatOptionsBuilder.build(); + this.functionCallingOptions = builder.functionCallingOptionsBuilder.build(); + this.frequencyPenalty = builder.frequencyPenalty; + this.presencePenalty = builder.presencePenalty; + this.logitBias = builder.logitBias; + this.maxTokens = builder.maxTokens; + this.model = builder.model; + this.responseFormat = builder.responseFormat; + this.stop = builder.stop; + this.tools = builder.tools; + this.user = builder.user; + this.n = builder.n; + this.toolChoice = builder.toolChoice; + this.seed = builder.seed; + } + + @Override + public String getModel() { + return this.model; + } + + @Override + public Float getFrequencyPenalty() { + return this.frequencyPenalty; + } + + @Override + public Map getLogitBias() { + return this.logitBias; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + @Override + public Integer getN() { + return this.n; + } + + @Override + public Float getPresencePenalty() { + return this.presencePenalty; + } + + @Override + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + @Override + public Integer getSeed() { + return this.seed; + } + + @Override + public List getStop() { + return this.stop; + } + + @Override + public List getTools() { + return this.tools; + } + + @Override + public String getToolChoice() { + return this.toolChoice; + } + + @Override + public String getUser() { + return this.user; + } + + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will + * make the output more random, while lower values like 0.2 will make it more + * focused and deterministic. We generally recommend altering this or top_p but + * not both. + */ + @Override + @JsonProperty("temperature") + public Float getTemperature() { + return this.chatOptions.getTemperature(); + } + + /** + * An alternative to sampling with temperature, called nucleus sampling, where the + * model considers the results of the tokens with top_p probability mass. So 0.1 + * means only the tokens comprising the top 10% probability mass are considered. + * We generally recommend altering this or temperature but not both. + */ + @Override + @JsonProperty("top_p") + public Float getTopP() { + return this.chatOptions.getTopP(); + } + + @Override + @JsonIgnore + public Integer getTopK() { + throw new UnsupportedOperationException("Unimplemented method 'getTopK'"); + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallingOptions.getFunctionCallbacks(); + } + + @Override + public Set getFunctions() { + return this.functionCallingOptions.getFunctions(); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((model == null) ? 0 : model.hashCode()); + result = prime * result + ((frequencyPenalty == null) ? 0 : frequencyPenalty.hashCode()); + result = prime * result + ((logitBias == null) ? 0 : logitBias.hashCode()); + result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode()); + result = prime * result + ((n == null) ? 0 : n.hashCode()); + result = prime * result + ((presencePenalty == null) ? 0 : presencePenalty.hashCode()); + result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode()); + result = prime * result + ((seed == null) ? 0 : seed.hashCode()); + result = prime * result + ((stop == null) ? 0 : stop.hashCode()); + result = prime * result + ((tools == null) ? 0 : tools.hashCode()); + result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode()); + result = prime * result + ((user == null) ? 0 : user.hashCode()); + result = prime * result + ((functionCallingOptions == null) ? 0 : functionCallingOptions.hashCode()); + result = prime * result + ((chatOptions == null) ? 0 : chatOptions.hashCode()); + return result; + } + + @Override + public boolean equals(final Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + OpenAiChatOptionsImpl other = (OpenAiChatOptionsImpl) obj; + if (this.model == null) { + if (other.model != null) + return false; + } + else if (!model.equals(other.model)) + return false; + if (this.frequencyPenalty == null) { + if (other.frequencyPenalty != null) + return false; + } + else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) + return false; + if (this.logitBias == null) { + if (other.logitBias != null) + return false; + } + else if (!this.logitBias.equals(other.logitBias)) + return false; + if (this.maxTokens == null) { + if (other.maxTokens != null) + return false; + } + else if (!this.maxTokens.equals(other.maxTokens)) + return false; + if (this.n == null) { + if (other.n != null) + return false; + } + else if (!this.n.equals(other.n)) + return false; + if (this.presencePenalty == null) { + if (other.presencePenalty != null) + return false; + } + else if (!this.presencePenalty.equals(other.presencePenalty)) + return false; + if (this.responseFormat == null) { + if (other.responseFormat != null) + return false; + } + else if (!this.responseFormat.equals(other.responseFormat)) + return false; + if (this.seed == null) { + if (other.seed != null) + return false; + } + else if (!this.seed.equals(other.seed)) + return false; + if (this.stop == null) { + if (other.stop != null) + return false; + } + else if (!stop.equals(other.stop)) + return false; + if ((this.chatOptions == null) != (other.chatOptions == null)) + return false; + else if (!this.chatOptions.equals(other.chatOptions)) + return false; + if ((this.tools == null) != (other.tools == null)) + return false; + else if (!tools.equals(other.tools)) + return false; + if ((this.toolChoice == null) != (other.toolChoice == null)) + return false; + else if (!toolChoice.equals(other.toolChoice)) + return false; + if ((this.user == null) != (other.user == null)) + return false; + else + return this.user.equals(other.user); + } + + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index 808c5f3513f..3dfa53590a3 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -68,7 +68,7 @@ public OpenAiEmbeddingClient(OpenAiApi openAiApi) { */ public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode) { this(openAiApi, metadataMode, - OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), + OpenAiEmbeddingOptionsBuilder.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptionsBuilder.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptionsBuilder.java new file mode 100644 index 00000000000..11b1dc9c0b6 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptionsBuilder.java @@ -0,0 +1,104 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai; + +/** + * @author youngmon + * @version 0.8.1 + */ +public class OpenAiEmbeddingOptionsBuilder { + + private String model; + + private String encodingFormat; + + private String user; + + private OpenAiEmbeddingOptionsBuilder() { + } + + public static OpenAiEmbeddingOptionsBuilder builder() { + return new OpenAiEmbeddingOptionsBuilder(); + } + + /** + * Copy Constructor for {@link OpenAiEmbeddingOptionsBuilder} + * @param options Existing {@link OpenAiEmbeddingOptions} + * @return new OpenAiEmbeddingOptionsBuilder + */ + public static OpenAiEmbeddingOptionsBuilder builder(OpenAiEmbeddingOptions options) { + return builder().withUser(options.getUser()) + .withModel(options.getModel()) + .withEncodingFormat(options.getEncodingFormat()); + } + + public OpenAiEmbeddingOptionsBuilder withUser(final String user) { + if (user == null) + return this; + this.user = user; + return this; + } + + public OpenAiEmbeddingOptionsBuilder withModel(final String model) { + if (model == null) + return this; + this.model = model; + return this; + } + + public OpenAiEmbeddingOptionsBuilder withEncodingFormat(final String encodingFormat) { + if (encodingFormat == null) + return this; + this.encodingFormat = encodingFormat; + return this; + } + + public OpenAiEmbeddingOptions build() { + return new OpenAiEmbeddingOptionsImpl(this); + } + + private static class OpenAiEmbeddingOptionsImpl extends OpenAiEmbeddingOptions { + + private final String model; + + private final String encodingFormat; + + private final String user; + + private OpenAiEmbeddingOptionsImpl(final OpenAiEmbeddingOptionsBuilder builder) { + this.user = builder.user; + this.encodingFormat = builder.encodingFormat; + this.model = builder.model; + } + + @Override + public String getModel() { + return this.model; + } + + @Override + public String getEncodingFormat() { + return this.encodingFormat; + } + + @Override + public String getUser() { + return this.user; + } + + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java index 69863ef9ffc..9b24cd57925 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java @@ -48,14 +48,14 @@ public class OpenAiImageClient implements ImageClient { private final static Logger logger = LoggerFactory.getLogger(OpenAiImageClient.class); - private OpenAiImageOptions defaultOptions; + private final OpenAiImageOptions defaultOptions; private final OpenAiImageApi openAiImageApi; public final RetryTemplate retryTemplate; public OpenAiImageClient(OpenAiImageApi openAiImageApi) { - this(openAiImageApi, OpenAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); + this(openAiImageApi, OpenAiImageOptionsBuilder.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); } public OpenAiImageClient(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions, @@ -123,39 +123,9 @@ private ImageResponse convertResponse(ResponseEntity messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = OpenAiChatOptions.builder() + var promptOptions = OpenAiChatOptionsBuilder.builder() .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") @@ -217,7 +217,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = OpenAiChatOptions.builder() + var promptOptions = OpenAiChatOptionsBuilder.builder() // .withModel(OpenAiApi.ChatModel.GPT_4_TURBO_PREVIEW.getValue()) .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("getCurrentWeather") @@ -252,7 +252,9 @@ void multiModalityEmbeddedImage() throws IOException { List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))); ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build())); + OpenAiChatOptionsBuilder.builder() + .withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()) + .build())); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "bowl"); @@ -266,7 +268,9 @@ void multiModalityImageUrl() throws IOException { "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))); ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build())); + OpenAiChatOptionsBuilder.builder() + .withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()) + .build())); logger.info(response.getResult().getOutput().getContent()); assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "bowl"); @@ -280,7 +284,9 @@ void streamingMultiModalityImageUrl() throws IOException { "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))); Flux response = streamingChatClient.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build())); + OpenAiChatOptionsBuilder.builder() + .withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()) + .build())); String content = response.collectList() .block() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index dcf0f303ff0..dce502bdca3 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -23,6 +23,10 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.openai.OpenAiAudioTranscriptionOptionsBuilder; +import org.springframework.ai.openai.OpenAiChatOptionsBuilder; +import org.springframework.ai.openai.OpenAiEmbeddingOptionsBuilder; +import org.springframework.ai.openai.OpenAiImageOptionsBuilder; import reactor.core.publisher.Flux; import org.springframework.ai.chat.prompt.Prompt; @@ -30,13 +34,9 @@ import org.springframework.ai.image.ImageMessage; import org.springframework.ai.image.ImagePrompt; import org.springframework.ai.openai.OpenAiAudioTranscriptionClient; -import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiEmbeddingClient; -import org.springframework.ai.openai.OpenAiEmbeddingOptions; import org.springframework.ai.openai.OpenAiImageClient; -import org.springframework.ai.openai.OpenAiImageOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; @@ -121,16 +121,16 @@ public void beforeEach() { retryListener = new TestRetryListener(); retryTemplate.registerListener(retryListener); - chatClient = new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder().build(), null, retryTemplate); + chatClient = new OpenAiChatClient(openAiApi, OpenAiChatOptionsBuilder.builder().build(), null, retryTemplate); embeddingClient = new OpenAiEmbeddingClient(openAiApi, MetadataMode.EMBED, - OpenAiEmbeddingOptions.builder().build(), retryTemplate); + OpenAiEmbeddingOptionsBuilder.builder().build(), retryTemplate); audioTranscriptionClient = new OpenAiAudioTranscriptionClient(openAiAudioApi, - OpenAiAudioTranscriptionOptions.builder() + OpenAiAudioTranscriptionOptionsBuilder.builder() .withModel("model") .withResponseFormat(TranscriptResponseFormat.JSON) .build(), retryTemplate); - imageClient = new OpenAiImageClient(openAiImageApi, OpenAiImageOptions.builder().build(), retryTemplate); + imageClient = new OpenAiImageClient(openAiImageApi, OpenAiImageOptionsBuilder.builder().build(), retryTemplate); } @Test diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java index 683f54b4826..2260ebebaea 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java @@ -20,7 +20,7 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.openai.OpenAiEmbeddingClient; -import org.springframework.ai.openai.OpenAiEmbeddingOptions; +import org.springframework.ai.openai.OpenAiEmbeddingOptionsBuilder; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; @@ -53,7 +53,7 @@ void defaultEmbedding() { void embedding3Large() { EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"), - OpenAiEmbeddingOptions.builder().withModel("text-embedding-3-large").build())); + OpenAiEmbeddingOptionsBuilder.builder().withModel("text-embedding-3-large").build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(3072); @@ -68,7 +68,7 @@ void embedding3Large() { void textEmbeddingAda002() { EmbeddingResponse embeddingResponse = embeddingClient.call(new EmbeddingRequest(List.of("Hello World"), - OpenAiEmbeddingOptions.builder().withModel("text-embedding-3-small").build())); + OpenAiEmbeddingOptionsBuilder.builder().withModel("text-embedding-3-small").build())); assertThat(embeddingResponse.getResults()).hasSize(1); assertThat(embeddingResponse.getResults().get(0)).isNotNull(); assertThat(embeddingResponse.getResults().get(0).getOutput()).hasSize(1536); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java index d367af422ad..d4a2554e07e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java @@ -17,69 +17,88 @@ public class ChatOptionsBuilder { - private class ChatOptionsImpl implements ChatOptions { + private Float temperature; - private Float temperature; + private Float topP; - private Float topP; - - private Integer topK; - - @Override - public Float getTemperature() { - return temperature; - } - - public void setTemperature(Float temperature) { - this.temperature = temperature; - } - - @Override - public Float getTopP() { - return topP; - } - - public void setTopP(Float topP) { - this.topP = topP; - } - - @Override - public Integer getTopK() { - return topK; - } - - public void setTopK(Integer topK) { - this.topK = topK; - } - - } - - private final ChatOptionsImpl options = new ChatOptionsImpl(); + private Integer topK; private ChatOptionsBuilder() { } + /** + * Creates a new {@link ChatOptions} instance. + * @return A new instance of ChatOptionsBuilder. + */ public static ChatOptionsBuilder builder() { return new ChatOptionsBuilder(); } - public ChatOptionsBuilder withTemperature(Float temperature) { - options.setTemperature(temperature); + /** + * Initializes a new {@link ChatOptionsBuilder} with settings from an existing + * {@link ChatOptions} object. + * @param options The ChatOptions object whose settings are to be used. + * @return A ChatOptionsBuilder instance initialized with the provided ChatOptions + * settings. + */ + public static ChatOptionsBuilder builder(final ChatOptions options) { + return builder().withTemperature(options.getTemperature()) + .withTopK(options.getTopK()) + .withTopP(options.getTopP()); + } + + public ChatOptionsBuilder withTemperature(final Float temperature) { + this.temperature = temperature; return this; } - public ChatOptionsBuilder withTopP(Float topP) { - options.setTopP(topP); + public ChatOptionsBuilder withTopP(final Float topP) { + this.topP = topP; return this; } - public ChatOptionsBuilder withTopK(Integer topK) { - options.setTopK(topK); + public ChatOptionsBuilder withTopK(final Integer topK) { + this.topK = topK; return this; } public ChatOptions build() { - return options; + return new ChatOptionsImpl(this.temperature, this.topP, this.topK); + } + + /** + * Created only by ChatOptionsBuilder for controlled setup. Hidden implementation, + * accessed via ChatOptions interface. Promotes modularity and easy use. + */ + private static class ChatOptionsImpl implements ChatOptions { + + private final Float temperature; + + private final Float topP; + + private final Integer topK; + + ChatOptionsImpl(final Float temperature, final Float topP, final Integer topK) { + this.temperature = temperature; + this.topP = topP; + this.topK = topK; + } + + @Override + public Float getTemperature() { + return temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + @Override + public Integer getTopK() { + return topK; + } + } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java index 7db065225ab..d6f83760a6a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java @@ -78,8 +78,23 @@ public class DefaultContentFormatter implements ContentFormatter { * Start building a new configuration. * @return The entry point for creating a new configuration. */ - public static Builder builder() { - return new Builder(); + public static ContentFormatterBuilder builder() { + return new ContentFormatterBuilder(); + } + + /** + * Initializes a new {@link ContentFormatterBuilder} with settings from an existing + * {@link DefaultContentFormatter} object. + * @param formatter The DefaultContentFormatter object whose settings are to be used. + * @return A ContentFormatterBuilder instance initialized with the provided + * DefaultContentFormatter settings. + */ + public static ContentFormatterBuilder builder(DefaultContentFormatter formatter) { + return builder().withExcludedEmbedMetadataKeys(formatter.getExcludedEmbedMetadataKeys()) + .withExcludedInferenceMetadataKeys(formatter.getExcludedInferenceMetadataKeys()) + .withMetadataSeparator(formatter.getMetadataSeparator()) + .withMetadataTemplate(formatter.getMetadataTemplate()) + .withTextTemplate(formatter.getTextTemplate()); } /** @@ -90,7 +105,7 @@ public static DefaultContentFormatter defaultConfig() { return builder().build(); } - private DefaultContentFormatter(Builder builder) { + private DefaultContentFormatter(ContentFormatterBuilder builder) { this.metadataTemplate = builder.metadataTemplate; this.metadataSeparator = builder.metadataSeparator; this.textTemplate = builder.textTemplate; @@ -98,7 +113,7 @@ private DefaultContentFormatter(Builder builder) { this.excludedEmbedMetadataKeys = builder.excludedEmbedMetadataKeys; } - public static class Builder { + public static class ContentFormatterBuilder { private String metadataTemplate = DEFAULT_METADATA_TEMPLATE; @@ -110,16 +125,7 @@ public static class Builder { private List excludedEmbedMetadataKeys = new ArrayList<>(); - private Builder() { - } - - public Builder from(DefaultContentFormatter fromFormatter) { - this.withExcludedEmbedMetadataKeys(fromFormatter.getExcludedEmbedMetadataKeys()) - .withExcludedInferenceMetadataKeys(fromFormatter.getExcludedInferenceMetadataKeys()) - .withMetadataSeparator(fromFormatter.getMetadataSeparator()) - .withMetadataTemplate(fromFormatter.getMetadataTemplate()) - .withTextTemplate(fromFormatter.getTextTemplate()); - return this; + private ContentFormatterBuilder() { } /** @@ -127,7 +133,7 @@ public Builder from(DefaultContentFormatter fromFormatter) { * @param metadataTemplate Metadata template to use. * @return this builder */ - public Builder withMetadataTemplate(String metadataTemplate) { + public ContentFormatterBuilder withMetadataTemplate(String metadataTemplate) { Assert.hasText(metadataTemplate, "Metadata Template must not be empty"); this.metadataTemplate = metadataTemplate; return this; @@ -138,7 +144,7 @@ public Builder withMetadataTemplate(String metadataTemplate) { * @param metadataSeparator Metadata separator to use. * @return this builder */ - public Builder withMetadataSeparator(String metadataSeparator) { + public ContentFormatterBuilder withMetadataSeparator(String metadataSeparator) { Assert.notNull(metadataSeparator, "Metadata separator must not be empty"); this.metadataSeparator = metadataSeparator; return this; @@ -149,7 +155,7 @@ public Builder withMetadataSeparator(String metadataSeparator) { * @param textTemplate Document's content template. * @return this builder */ - public Builder withTextTemplate(String textTemplate) { + public ContentFormatterBuilder withTextTemplate(String textTemplate) { Assert.hasText(textTemplate, "Document's text template must not be empty"); this.textTemplate = textTemplate; return this; @@ -161,13 +167,13 @@ public Builder withTextTemplate(String textTemplate) { * @param excludedInferenceMetadataKeys Excluded inference metadata keys to use. * @return this builder */ - public Builder withExcludedInferenceMetadataKeys(List excludedInferenceMetadataKeys) { + public ContentFormatterBuilder withExcludedInferenceMetadataKeys(List excludedInferenceMetadataKeys) { Assert.notNull(excludedInferenceMetadataKeys, "Excluded inference metadata keys must not be null"); this.excludedInferenceMetadataKeys = excludedInferenceMetadataKeys; return this; } - public Builder withExcludedInferenceMetadataKeys(String... keys) { + public ContentFormatterBuilder withExcludedInferenceMetadataKeys(String... keys) { Assert.notNull(keys, "Excluded inference metadata keys must not be null"); this.excludedInferenceMetadataKeys.addAll(Arrays.asList(keys)); return this; @@ -178,13 +184,13 @@ public Builder withExcludedInferenceMetadataKeys(String... keys) { * @param excludedEmbedMetadataKeys Excluded Embed metadata keys to use. * @return this builder */ - public Builder withExcludedEmbedMetadataKeys(List excludedEmbedMetadataKeys) { + public ContentFormatterBuilder withExcludedEmbedMetadataKeys(List excludedEmbedMetadataKeys) { Assert.notNull(excludedEmbedMetadataKeys, "Excluded Embed metadata keys must not be null"); this.excludedEmbedMetadataKeys = excludedEmbedMetadataKeys; return this; } - public Builder withExcludedEmbedMetadataKeys(String... keys) { + public ContentFormatterBuilder withExcludedEmbedMetadataKeys(String... keys) { Assert.notNull(keys, "Excluded Embed metadata keys must not be null"); this.excludedEmbedMetadataKeys.addAll(Arrays.asList(keys)); return this; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java index c7b7f3eb115..fb308666220 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/image/ImageOptionsBuilder.java @@ -15,104 +15,130 @@ */ package org.springframework.ai.image; +/** + * Builder for {@link ImageOptions}. This builder creates option objects required for + * image generation. + * + * @author youngmon + * @since 0.8.1 + */ public class ImageOptionsBuilder { - private class ImageModelOptionsImpl implements ImageOptions { + private Integer n; + + private String model; + + private Integer width; + + private Integer height; + + private String responseFormat; + + private ImageOptionsBuilder() { + } + + /** + * Creates a new {@link ImageOptionsBuilder} instance. + * @return A new instance of ImageOptionsBuilder. + */ + public static ImageOptionsBuilder builder() { + return new ImageOptionsBuilder(); + } + + /** + * Initializes a new {@link ImageOptionsBuilder} with settings from an existing + * {@link ImageOptions} object. + * @param options The ImageOptions object whose settings are to be used. + * @return A ImageOptionsBuilder instance initialized with the provided ImageOptions + * settings. + */ + public static ImageOptionsBuilder builder(final ImageOptions options) { + return builder().withN(options.getN()) + .withModel(options.getModel()) + .withHeight(options.getHeight()) + .withWidth(options.getWidth()) + .withResponseFormat(options.getResponseFormat()); + } + + public ImageOptions build() { + return new ImageOptionsImpl(this.n, this.model, this.width, this.height, this.responseFormat); + } + + public ImageOptionsBuilder withN(final Integer n) { + this.n = n; + return this; + } + + public ImageOptionsBuilder withModel(final String model) { + this.model = model; + return this; + } + + public ImageOptionsBuilder withResponseFormat(final String responseFormat) { + this.responseFormat = responseFormat; + return this; + } - private Integer n; + public ImageOptionsBuilder withWidth(final Integer width) { + this.width = width; + return this; + } + + public ImageOptionsBuilder withHeight(final Integer height) { + this.height = height; + return this; + } - private String model; + /** + * Created only by ImageOptionsBuilder for controlled setup. Hidden implementation, + * accessed via ImageOptions interface. Promotes modularity and easy use. + */ + private static class ImageOptionsImpl implements ImageOptions { - private Integer width; + private final Integer n; - private Integer height; + private final String model; - private String responseFormat; + private final Integer width; + + private final Integer height; + + private final String responseFormat; + + private ImageOptionsImpl(final Integer n, final String model, final Integer width, final Integer height, + final String responseFormat) { + this.n = n; + this.model = model; + this.width = width; + this.height = height; + this.responseFormat = responseFormat; + } @Override public Integer getN() { return n; } - public void setN(Integer n) { - this.n = n; - } - @Override public String getModel() { return model; } - public void setModel(String model) { - this.model = model; - } - @Override public String getResponseFormat() { return responseFormat; } - public void setResponseFormat(String responseFormat) { - this.responseFormat = responseFormat; - } - @Override public Integer getWidth() { return width; } - public void setWidth(Integer width) { - this.width = width; - } - @Override public Integer getHeight() { return height; } - public void setHeight(Integer height) { - this.height = height; - } - - } - - private final ImageModelOptionsImpl options = new ImageModelOptionsImpl(); - - private ImageOptionsBuilder() { - - } - - public static ImageOptionsBuilder builder() { - return new ImageOptionsBuilder(); - } - - public ImageOptionsBuilder withN(Integer n) { - options.setN(n); - return this; - } - - public ImageOptionsBuilder withModel(String model) { - options.setModel(model); - return this; - } - - public ImageOptionsBuilder withResponseFormat(String responseFormat) { - options.setResponseFormat(responseFormat); - return this; - } - - public ImageOptionsBuilder withWidth(Integer width) { - options.setWidth(width); - return this; - } - - public ImageOptionsBuilder withHeight(Integer height) { - options.setHeight(height); - return this; - } - - public ImageOptions build() { - return options; } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index 146b35c4715..0108b77ddd9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -15,6 +15,7 @@ */ package org.springframework.ai.model.function; +import com.fasterxml.jackson.annotation.JsonIgnore; import java.util.List; import java.util.Set; @@ -31,35 +32,14 @@ public interface FunctionCallingOptions { * ChatClient registry to be used in the chat completion requests. * @return Return the Function Callbacks to be registered with the ChatClient. */ + @JsonIgnore List getFunctionCallbacks(); - /** - * Set the Function Callbacks to be registered with the ChatClient. - * @param functionCallbacks the Function Callbacks to be registered with the - * ChatClient. - */ - void setFunctionCallbacks(List functionCallbacks); - /** * @return List of function names from the ChatClient registry to be used in the next * chat completion requests. */ + @JsonIgnore Set getFunctions(); - /** - * Set the list of function names from the ChatClient registry to be used in the next - * chat completion requests. - * @param functions the list of function names from the ChatClient registry to be used - * in the next chat completion requests. - */ - void setFunctions(Set functions); - - /** - * @return Returns FunctionCallingOptionsBuilder to create a new instance of - * FunctionCallingOptions. - */ - public static FunctionCallingOptionsBuilder builder() { - return new FunctionCallingOptionsBuilder(); - } - } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index 8b30da67da7..1abf562e057 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -19,126 +19,94 @@ import java.util.HashSet; import java.util.List; import java.util.Set; - -import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.lang.NonNull; import org.springframework.util.Assert; /** - * Builder for {@link FunctionCallingOptions}. Using the {@link FunctionCallingOptions} - * permits options portability between different AI providers that support - * function-calling. + * Builder for {@link FunctionCallingOptions}. This builder creates option objects + * required for function-calling. * - * @author Christian Tzolov + * @author youngmon * @since 0.8.1 */ public class FunctionCallingOptionsBuilder { - private final PortableFunctionCallingOptions options; + private final List functionCallbacks = new ArrayList<>(); - public FunctionCallingOptionsBuilder() { - this.options = new PortableFunctionCallingOptions(); - } + private final Set functions = new HashSet<>(); - public FunctionCallingOptionsBuilder withFunctionCallbacks(List functionCallbacks) { - this.options.setFunctionCallbacks(functionCallbacks); - return this; + private FunctionCallingOptionsBuilder() { } - public FunctionCallingOptionsBuilder withFunctionCallback(FunctionCallback functionCallback) { - Assert.notNull(functionCallback, "FunctionCallback must not be null"); - this.options.getFunctionCallbacks().add(functionCallback); - return this; + /** + * Creates a new {@link FunctionCallingOptionsBuilder} instance. + * @return A new instance of FunctionCallingOptionsBuilder. + */ + public static FunctionCallingOptionsBuilder builder() { + return new FunctionCallingOptionsBuilder(); } - public FunctionCallingOptionsBuilder withFunctions(Set functions) { - this.options.setFunctions(functions); - return this; + /** + * Initializes a new {@link FunctionCallingOptionsBuilder} with settings from an + * existing {@link FunctionCallingOptions} object. + * @param options The FunctionCallingOptions object whose settings are to be used. + * @return A FunctionCallingOptionsBuilder instance initialized with the provided + * FunctionCallingOptions settings. + */ + public static FunctionCallingOptionsBuilder builder(final FunctionCallingOptions options) { + return builder().withFunctions(options.getFunctions()).withFunctionCallbacks(options.getFunctionCallbacks()); } - public FunctionCallingOptionsBuilder withFunction(String function) { - Assert.notNull(function, "Function must not be null"); - this.options.getFunctions().add(function); - return this; + public FunctionCallingOptions build() { + return new FunctionCallingOptionsImpl(this.functionCallbacks, this.functions); } - public FunctionCallingOptionsBuilder withTemperature(Float temperature) { - this.options.setTemperature(temperature); + public FunctionCallingOptionsBuilder withFunctionCallbacks( + @NonNull final List functionCallbacks) { + Assert.notNull(functionCallbacks, "FunctionCallback must not be null"); + this.functionCallbacks.addAll(functionCallbacks); return this; } - public FunctionCallingOptionsBuilder withTopP(Float topP) { - this.options.setTopP(topP); + public FunctionCallingOptionsBuilder withFunctionCallback(@NonNull final FunctionCallback functionCallback) { + Assert.notNull(functionCallback, "FunctionCallback must not be null"); + this.functionCallbacks.add(functionCallback); return this; } - public FunctionCallingOptionsBuilder withTopK(Integer topK) { - this.options.setTopK(topK); + public FunctionCallingOptionsBuilder withFunctions(@NonNull final Set functions) { + Assert.notNull(functions, "Functions must not be null"); + this.functions.addAll(functions); return this; } - public PortableFunctionCallingOptions build() { - return this.options; + public FunctionCallingOptionsBuilder withFunction(@NonNull final String function) { + Assert.notNull(function, "Function must not be null"); + this.functions.add(function); + return this; } - public static class PortableFunctionCallingOptions implements FunctionCallingOptions, ChatOptions { - - private List functionCallbacks = new ArrayList<>(); - - private Set functions = new HashSet<>(); + private static class FunctionCallingOptionsImpl implements FunctionCallingOptions { - private Float temperature; + private final List functionCallbacks; - private Float topP; + private final Set functions; - private Integer topK; + FunctionCallingOptionsImpl(final List functionCallbacks, final Set functions) { + this.functionCallbacks = functionCallbacks; + this.functions = functions; + } @Override public List getFunctionCallbacks() { return this.functionCallbacks; } - public void setFunctionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); - this.functionCallbacks = functionCallbacks; - } - @Override public Set getFunctions() { return this.functions; } - public void setFunctions(Set functions) { - Assert.notNull(functions, "Functions must not be null"); - this.functions = functions; - } - - @Override - public Float getTemperature() { - return this.temperature; - } - - public void setTemperature(Float temperature) { - this.temperature = temperature; - } - - @Override - public Float getTopP() { - return this.topP; - } - - public void setTopP(Float topP) { - this.topP = topP; - } - - @Override - public Integer getTopK() { - return this.topK; - } - - public void setTopK(Integer topK) { - this.topK = topK; - } - } -} \ No newline at end of file +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/PortableFunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/PortableFunctionCallingOptions.java new file mode 100644 index 00000000000..adc229b8c42 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/PortableFunctionCallingOptions.java @@ -0,0 +1,28 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.function; + +import org.springframework.ai.chat.prompt.ChatOptions; + +/** + * Including {@link FunctionCallingOptions} and {@link ChatOptions}. This Interface allows + * the extracion of ChatOptions and FunctionCallingOptions. + * + * @author youngmon + */ +public interface PortableFunctionCallingOptions extends FunctionCallingOptions, ChatOptions { + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/PortableFunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/PortableFunctionCallingOptionsBuilder.java new file mode 100644 index 00000000000..44ff052802e --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/PortableFunctionCallingOptionsBuilder.java @@ -0,0 +1,151 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.function; + +import java.util.List; +import java.util.Set; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; + +/** + * Builder for {@link PortableFunctionCallingOptions}. Using the + * {@link PortableFunctionCallingOptions} permits options portability between different AI + * providers that support function-calling. + * + * @author youngmon + * @author Christian Tzolov + * @since 0.8.1 + */ +public class PortableFunctionCallingOptionsBuilder { + + private FunctionCallingOptionsBuilder functionCallingOptionBuilder = FunctionCallingOptionsBuilder.builder(); + + private ChatOptionsBuilder chatOptionsBuilder = ChatOptionsBuilder.builder(); + + private PortableFunctionCallingOptionsBuilder() { + } + + /** + * Creates a new {@link PortableFunctionCallingOptionsBuilder} instance. + * @return A new instance of PortableFunctionCallingOptionsBuilder. + */ + public static PortableFunctionCallingOptionsBuilder builder() { + return new PortableFunctionCallingOptionsBuilder(); + } + + /** + * Initializes a new {@link PortableFunctionCallingOptionsBuilder} with settings from + * an existing {@link PortableFunctionCallingOptions} object. + * @param options The PortableFunctionCallingOptions object whose settings are to be + * used. + * @return A PortableFunctionCallingOptionsBuilder instance initialized with the + * provided PortableFunctionCallingOptions settings. + */ + public static PortableFunctionCallingOptionsBuilder builder(final PortableFunctionCallingOptions options) { + return builder().withChatOptions(options).withFunctionCallingOptions(options); + } + + public PortableFunctionCallingOptions build() { + return new PortableFunctionCallingOptionsImpl(this.functionCallingOptionBuilder.build(), + this.chatOptionsBuilder.build()); + } + + public PortableFunctionCallingOptionsBuilder withFunctionCallbacks(final List functionCallbacks) { + this.functionCallingOptionBuilder.withFunctionCallbacks(functionCallbacks); + return this; + } + + public PortableFunctionCallingOptionsBuilder withFunctionCallback(final FunctionCallback functionCallback) { + this.functionCallingOptionBuilder.withFunctionCallback(functionCallback); + return this; + } + + public PortableFunctionCallingOptionsBuilder withFunctions(final Set functions) { + this.functionCallingOptionBuilder.withFunctions(functions); + return this; + } + + public PortableFunctionCallingOptionsBuilder withFunction(final String function) { + this.functionCallingOptionBuilder.withFunction(function); + return this; + } + + public PortableFunctionCallingOptionsBuilder withTemperature(final Float temperature) { + this.chatOptionsBuilder.withTemperature(temperature); + return this; + } + + public PortableFunctionCallingOptionsBuilder withTopP(final Float topP) { + this.chatOptionsBuilder.withTopP(topP); + return this; + } + + public PortableFunctionCallingOptionsBuilder withTopK(final Integer topK) { + this.chatOptionsBuilder.withTopK(topK); + return this; + } + + public PortableFunctionCallingOptionsBuilder withChatOptions(final ChatOptions options) { + this.chatOptionsBuilder = ChatOptionsBuilder.builder(options); + return this; + } + + public PortableFunctionCallingOptionsBuilder withFunctionCallingOptions(final FunctionCallingOptions options) { + this.functionCallingOptionBuilder = FunctionCallingOptionsBuilder.builder(options); + return this; + } + + private static class PortableFunctionCallingOptionsImpl implements PortableFunctionCallingOptions { + + private final FunctionCallingOptions functionCallingOptions; + + private final ChatOptions chatOptions; + + PortableFunctionCallingOptionsImpl(final FunctionCallingOptions functionCallingOptions, + final ChatOptions chatOptions) { + this.functionCallingOptions = functionCallingOptions; + this.chatOptions = chatOptions; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallingOptions.getFunctionCallbacks(); + } + + @Override + public Set getFunctions() { + return this.functionCallingOptions.getFunctions(); + } + + @Override + public Float getTemperature() { + return this.chatOptions.getTemperature(); + } + + @Override + public Float getTopP() { + return this.chatOptions.getTopP(); + } + + @Override + public Integer getTopK() { + return this.chatOptions.getTopK(); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatOptionsTests.java similarity index 52% rename from spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java rename to spring-ai-core/src/test/java/org/springframework/ai/chat/ChatOptionsTests.java index 19ece24f84e..25e73648c08 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatOptionsTests.java @@ -16,27 +16,18 @@ package org.springframework.ai.chat; import static org.assertj.core.api.Assertions.assertThat; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptionsBuilder; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallback; -import org.springframework.ai.model.function.FunctionCallbackWrapper; -import org.springframework.ai.model.function.FunctionCallingOptions; /** - * Unit Tests for {@link Prompt}. + * Unit Tests for {@link ChatOptions}. * * @author youngmon * @since 0.8.1 */ -public class ChatBuilderTests { +public class ChatOptionsTests { @Test void createNewChatOptionsTest() { @@ -67,43 +58,13 @@ void duplicateChatOptionsTest() { .withTopK(initTopK) .build(); - } - - @Test - void createFunctionCallingOptionTest() { - Float temperature = 1.1f; - Float topP = 2.2f; - Integer topK = 111; - List functionCallbacks = new ArrayList<>(); - Set functions = new HashSet<>(); - - String func = "func"; - FunctionCallback cb = FunctionCallbackWrapper.builder(i -> i) - .withName("cb") - .withDescription("cb") - .build(); - - functions.add(func); - functionCallbacks.add(cb); - - FunctionCallingOptions options = FunctionCallingOptions.builder() - .withFunctionCallbacks(functionCallbacks) - .withFunctions(functions) - .withTopK(topK) - .withTopP(topP) - .withTemperature(temperature) - .build(); - - // Callback Functions - assertThat(options.getFunctionCallbacks()).isNotNull(); - assertThat(options.getFunctionCallbacks().size()).isEqualTo(1); - assertThat(options.getFunctionCallbacks().contains(cb)); + ChatOptions options1 = ChatOptionsBuilder.builder(options).build(); - // Functions - assertThat(options.getFunctions()).isNotNull(); - assertThat(options.getFunctions().size()).isEqualTo(1); - assertThat(options.getFunctions().contains(func)); + assertThat(options.getTopP()).isEqualTo(options1.getTopP()); + assertThat(options.getTopK()).isEqualTo(options1.getTopK()); + assertThat(options.getTemperature()).isEqualTo(options1.getTemperature()); + assertThat(options).isNotSameAs(options1); } -} +} \ No newline at end of file diff --git a/spring-ai-core/src/test/java/org/springframework/ai/image/ImageOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/image/ImageOptionsTests.java new file mode 100644 index 00000000000..14fbad39886 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/image/ImageOptionsTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.image; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link ImageOptions}. + * + * @author youngmon + * @since 0.8.1 + */ +public class ImageOptionsTests { + + @Test + void createImageOptionsTest() { + ImageOptions options = ImageOptionsBuilder.builder().build(); + + assertThat(options).isNotNull(); + assertThat(options).isInstanceOf(ImageOptions.class); + } + + @Test + void initImageOptionsTest() { + ImageOptions options = ImageOptionsBuilder.builder().build(); + + assertThat(options.getN()).isNull(); + assertThat(options.getModel()).isNull(); + assertThat(options.getHeight()).isNull(); + assertThat(options.getHeight()).isNull(); + assertThat(options.getWidth()).isNull(); + } + + @Test + void assignImageOptionsTest() { + String responseFormat = "url"; + String model = "dall-e-3"; + Integer n = 3; + Integer width = 512; + Integer height = 768; + + ImageOptions options = ImageOptionsBuilder.builder() + .withResponseFormat(responseFormat) + .withWidth(width) + .withHeight(height) + .withModel(model) + .withN(n) + .build(); + + assertThat(options.getWidth()).isEqualTo(width); + assertThat(options.getHeight()).isEqualTo(height); + assertThat(options.getModel()).isEqualTo(model); + assertThat(options.getN()).isEqualTo(n); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + } + + @Test + void immutableTest() { + ImageOptions options = ImageOptionsBuilder.builder().build(); + ImageOptions options1 = ImageOptionsBuilder.builder(options).build(); + + assertThat(options).isNotSameAs(options1); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/function/FunctionCallingOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/function/FunctionCallingOptionsTests.java new file mode 100644 index 00000000000..14d902b4ce6 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/function/FunctionCallingOptionsTests.java @@ -0,0 +1,178 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.function; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; + +/** + * Unit Tests for {@link FunctionCallingOptions}. + * + * @author youngmon + * @since 0.8.1 + */ +public class FunctionCallingOptionsTests { + + final String func = "func"; + + final FunctionCallback cb = FunctionCallbackWrapper.builder(i -> i) + .withName("cb") + .withDescription("cb") + .build(); + + @Test + void createFunctionCallingOptionsTest() { + FunctionCallingOptions options = FunctionCallingOptionsBuilder.builder().build(); + + assertThat(options).isNotNull(); + assertThat(options).isInstanceOf(FunctionCallingOptions.class); + } + + @Test + void createPortableFunctionCallingOptionsTest() { + PortableFunctionCallingOptions options = PortableFunctionCallingOptionsBuilder.builder().build(); + + assertThat(options).isNotNull(); + assertThat(options).isInstanceOf(PortableFunctionCallingOptions.class); + } + + @Test + void castingPortableFunctionCallingOptionsTest() { + ChatOptions chatOptions = PortableFunctionCallingOptionsBuilder.builder().build(); + FunctionCallingOptions functionCallingOptions = PortableFunctionCallingOptionsBuilder.builder().build(); + + assertThat(chatOptions).isNotNull(); + assertThat(functionCallingOptions).isNotNull(); + + assertThat(chatOptions).isInstanceOf(ChatOptions.class).isInstanceOf(PortableFunctionCallingOptions.class); + assertThat(functionCallingOptions).isInstanceOf(FunctionCallingOptions.class) + .isInstanceOf(PortableFunctionCallingOptions.class); + } + + @Test + void initFunctionCallingOptionsTest() { + FunctionCallingOptions options = FunctionCallingOptionsBuilder.builder().build(); + + // Callback Functions + assertThat(options.getFunctionCallbacks()).isNotNull(); + assertThat(options.getFunctionCallbacks().isEmpty()).isTrue(); + + // Functions + assertThat(options.getFunctions()).isNotNull(); + assertThat(options.getFunctions().isEmpty()).isTrue(); + } + + @Test + void initPortableFunctionCallingOptionsTest() { + PortableFunctionCallingOptions options = PortableFunctionCallingOptionsBuilder.builder().build(); + + // Callback Funcions + assertThat(options.getFunctionCallbacks()).isNotNull(); + assertThat(options.getFunctionCallbacks().isEmpty()).isTrue(); + + // Functions + assertThat(options.getFunctions()).isNotNull(); + assertThat(options.getFunctions().isEmpty()).isTrue(); + } + + @Test + void assignFunctionCallingOptionsTest() { + FunctionCallingOptions options = FunctionCallingOptionsBuilder.builder() + .withFunction(func) + .withFunctionCallback(cb) + .build(); + + // Callback Functions + assertThat(options.getFunctionCallbacks()).isNotNull(); + assertThat(options.getFunctions().size()).isEqualTo(1); + assertThat(options.getFunctionCallbacks().contains(cb)); + + // Functions + assertThat(options.getFunctions()).isNotNull(); + assertThat(options.getFunctions().size()).isEqualTo(1); + assertThat(options.getFunctions().contains(func)); + } + + @Test + void assignPortableFunctionCallingOptionsTest() { + PortableFunctionCallingOptions options = PortableFunctionCallingOptionsBuilder.builder() + .withFunction(func) + .withFunctionCallback(cb) + .build(); + + // Callback Functions + assertThat(options.getFunctionCallbacks()).isNotNull(); + assertThat(options.getFunctions().size()).isEqualTo(1); + assertThat(options.getFunctionCallbacks().contains(cb)); + + // Functions + assertThat(options.getFunctions()).isNotNull(); + assertThat(options.getFunctions().size()).isEqualTo(1); + assertThat(options.getFunctions().contains(func)); + } + + @Test + void convertPortableFunctionCallingOptionsTest() { + String func = "func"; + Integer topK = 123; + + PortableFunctionCallingOptions options = PortableFunctionCallingOptionsBuilder.builder() + .withFunction(func) + .withTopK(topK) + .build(); + + assertThat(options.getFunctions()).contains(func); + assertThat(options.getTopK()).isEqualTo(topK); + + // type + assertThat(options).isInstanceOf(ChatOptions.class); + assertThat(options).isInstanceOf(FunctionCallingOptions.class); + + // up casting + FunctionCallingOptions functionCallingOptions = options; + ChatOptions chatOptions = options; + + assertThat(functionCallingOptions.getFunctions()).contains(func); + assertThat(chatOptions.getTopK()).isEqualTo(topK); + + // down casting + { + PortableFunctionCallingOptions tmp = (PortableFunctionCallingOptions) functionCallingOptions; + assertThat(tmp.getFunctions()).contains(func); + } + { + PortableFunctionCallingOptions tmp = (PortableFunctionCallingOptions) chatOptions; + assertThat(tmp.getTopK()).isEqualTo(topK); + } + + // build new Instance for modify + String newFunc = "newFunc"; + Float topP = 1.2f; + + FunctionCallingOptions newFunctionCallingOptions = FunctionCallingOptionsBuilder.builder(options) + .withFunction(newFunc) + .build(); + ChatOptions newChatOptions = ChatOptionsBuilder.builder(options).withTopP(topP).build(); + + assertThat(newFunctionCallingOptions.getFunctions()).contains(func, newFunc); + assertThat(newChatOptions.getTopK()).isEqualTo(topK); + assertThat(newChatOptions.getTopP()).isEqualTo(topP); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionOptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionOptionProperties.java new file mode 100644 index 00000000000..2ed81924ba2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionOptionProperties.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.openai; + +import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest.GranularityType; + +/** + * @author Michael Lavelle + * @author Christian Tzolov + * @since 0.8.1 + */ +public class OpenAiAudioTranscriptionOptionProperties implements OpenAiAudioTranscriptionOptions { + + public static final String DEFAULT_TRANSCRIPTION_MODEL = OpenAiAudioApi.WhisperModel.WHISPER_1.getValue(); + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + private static final OpenAiAudioApi.TranscriptResponseFormat DEFAULT_RESPONSE_FORMAT = OpenAiAudioApi.TranscriptResponseFormat.TEXT; + + private String model = DEFAULT_TRANSCRIPTION_MODEL; + + private Float temperature = DEFAULT_TEMPERATURE.floatValue(); + + private TranscriptResponseFormat responseFormat = DEFAULT_RESPONSE_FORMAT; + + private String prompt; + + private String language; + + private GranularityType granularityType; + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public String getLanguage() { + return this.language; + } + + public void setLanguage(String language) { + this.language = language; + } + + @Override + public String getPrompt() { + return this.prompt; + } + + public void setPrompt(String prompt) { + this.prompt = prompt; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public TranscriptResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(TranscriptResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public GranularityType getGranularityType() { + return this.granularityType; + } + + public void setGranularityType(GranularityType granularityType) { + this.granularityType = granularityType; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java index ab27093f80b..7ed908cad58 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAudioTranscriptionProperties.java @@ -15,8 +15,6 @@ */ package org.springframework.ai.autoconfigure.openai; -import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; -import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -25,24 +23,14 @@ public class OpenAiAudioTranscriptionProperties extends OpenAiParentProperties { public static final String CONFIG_PREFIX = "spring.ai.openai.audio.transcription"; - public static final String DEFAULT_TRANSCRIPTION_MODEL = OpenAiAudioApi.WhisperModel.WHISPER_1.getValue(); - - private static final Double DEFAULT_TEMPERATURE = 0.7; - - private static final OpenAiAudioApi.TranscriptResponseFormat DEFAULT_RESPONSE_FORMAT = OpenAiAudioApi.TranscriptResponseFormat.TEXT; - @NestedConfigurationProperty - private OpenAiAudioTranscriptionOptions options = OpenAiAudioTranscriptionOptions.builder() - .withModel(DEFAULT_TRANSCRIPTION_MODEL) - .withTemperature(DEFAULT_TEMPERATURE.floatValue()) - .withResponseFormat(DEFAULT_RESPONSE_FORMAT) - .build(); + private OpenAiAudioTranscriptionOptionProperties options = new OpenAiAudioTranscriptionOptionProperties(); - public OpenAiAudioTranscriptionOptions getOptions() { + public OpenAiAudioTranscriptionOptionProperties getOptions() { return options; } - public void setOptions(OpenAiAudioTranscriptionOptions options) { + public void setOptions(OpenAiAudioTranscriptionOptionProperties options) { this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index 848456468f2..9282f55c801 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -24,6 +24,7 @@ import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.OpenAiImageOptionsBuilder; import org.springframework.ai.openai.OpenAiAudioSpeechClient; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiAudioApi; @@ -118,7 +119,8 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler); - return new OpenAiImageClient(openAiImageApi, imageProperties.getOptions(), retryTemplate); + return new OpenAiImageClient(openAiImageApi, + OpenAiImageOptionsBuilder.builder(imageProperties.getOptions()).build(), retryTemplate); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatOptionsProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatOptionsProperties.java new file mode 100644 index 00000000000..65fdd3323ad --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatOptionsProperties.java @@ -0,0 +1,216 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.openai; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; +import org.springframework.ai.openai.api.OpenAiApi.FunctionTool; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +public class OpenAiChatOptionsProperties extends OpenAiChatOptions { + + public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; + + private static final Double DEFAULT_TEMPERATURE = 0.7; + + private String model = DEFAULT_CHAT_MODEL; + + private Float temperature = DEFAULT_TEMPERATURE.floatValue(); + + private Float frequencyPenalty; + + private ResponseFormat responseFormat; + + private Map logitBias; + + private Integer maxTokens; + + private Integer n; + + private Float presencePenalty; + + private Integer seed; + + private String toolChoice; + + private String user; + + private Float topP; + + @NestedConfigurationProperty + private List tools; + + @NestedConfigurationProperty + private List stop; + + @NestedConfigurationProperty + private List functionCallbacks = new ArrayList<>(); + + @NestedConfigurationProperty + private Set functions = new HashSet<>(); + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Float getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Float frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Map getLogitBias() { + return this.logitBias; + } + + public void setLogitBias(Map logitBias) { + this.logitBias = logitBias; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Integer getN() { + return this.n; + } + + public void setN(Integer n) { + this.n = n; + } + + @Override + public Float getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Float presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public ResponseFormat getResponseFormat() { + return this.responseFormat; + } + + public void setResponseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + @Override + public Integer getSeed() { + return this.seed; + } + + public void setSeed(Integer seed) { + this.seed = seed; + } + + @Override + public List getStop() { + return this.stop; + } + + public void setStop(List stop) { + this.stop = stop; + } + + @Override + public List getTools() { + return this.tools; + } + + public void setTools(List tools) { + this.tools = tools; + } + + @Override + public String getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(String toolChoice) { + this.toolChoice = toolChoice; + } + + @Override + public String getUser() { + return this.user; + } + + public void setUser(String user) { + this.user = user; + } + + @Override + public Float getTemperature() { + return this.temperature; + } + + public void setTemperature(Float temperature) { + this.temperature = temperature; + } + + @Override + public Float getTopP() { + return this.topP; + } + + public void setTopP(Float topP) { + this.topP = topP; + } + + @Override + public List getFunctionCallbacks() { + return this.functionCallbacks; + } + + @Override + public Set getFunctions() { + return this.functions; + } + + public void setFunctions(Set functions) { + this.functions = functions; + } + + @Override + public Integer getTopK() { + return null; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java index 5e9940b4238..2407900b461 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiChatProperties.java @@ -15,7 +15,6 @@ */ package org.springframework.ai.autoconfigure.openai; -import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -24,26 +23,19 @@ public class OpenAiChatProperties extends OpenAiParentProperties { public static final String CONFIG_PREFIX = "spring.ai.openai.chat"; - public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; - - private static final Double DEFAULT_TEMPERATURE = 0.7; - /** * Enable OpenAI chat client. */ private boolean enabled = true; @NestedConfigurationProperty - private OpenAiChatOptions options = OpenAiChatOptions.builder() - .withModel(DEFAULT_CHAT_MODEL) - .withTemperature(DEFAULT_TEMPERATURE.floatValue()) - .build(); + private OpenAiChatOptionsProperties options = new OpenAiChatOptionsProperties(); - public OpenAiChatOptions getOptions() { + public OpenAiChatOptionsProperties getOptions() { return options; } - public void setOptions(OpenAiChatOptions options) { + public void setOptions(OpenAiChatOptionsProperties options) { this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingOptionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingOptionProperties.java new file mode 100644 index 00000000000..235eddcda96 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingOptionProperties.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.openai; + +import org.springframework.ai.openai.OpenAiEmbeddingOptions; + +public class OpenAiEmbeddingOptionProperties extends OpenAiEmbeddingOptions { + + public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"; + + private String model = DEFAULT_EMBEDDING_MODEL; + + private String encodingFormat; + + private String user; + + @Override + public String getModel() { + return this.model; + } + + public void setModel(final String model) { + this.model = model; + } + + @Override + public String getEncodingFormat() { + return this.encodingFormat; + } + + public void setEncodingFormat(final String encodingFormat) { + this.encodingFormat = encodingFormat; + } + + @Override + public String getUser() { + return this.user; + } + + public void setUser(final String user) { + this.user = user; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java index fa796d92f3f..5b6695bb4e0 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiEmbeddingProperties.java @@ -16,17 +16,18 @@ package org.springframework.ai.autoconfigure.openai; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.openai.OpenAiEmbeddingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +/** + * @author youngmon + * @version 0.8.1 + */ @ConfigurationProperties(OpenAiEmbeddingProperties.CONFIG_PREFIX) public class OpenAiEmbeddingProperties extends OpenAiParentProperties { public static final String CONFIG_PREFIX = "spring.ai.openai.embedding"; - public static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"; - /** * Enable OpenAI embedding client. */ @@ -35,15 +36,13 @@ public class OpenAiEmbeddingProperties extends OpenAiParentProperties { private MetadataMode metadataMode = MetadataMode.EMBED; @NestedConfigurationProperty - private OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder() - .withModel(DEFAULT_EMBEDDING_MODEL) - .build(); + private OpenAiEmbeddingOptionProperties options = new OpenAiEmbeddingOptionProperties(); - public OpenAiEmbeddingOptions getOptions() { + public OpenAiEmbeddingOptionProperties getOptions() { return this.options; } - public void setOptions(OpenAiEmbeddingOptions options) { + public void setOptions(OpenAiEmbeddingOptionProperties options) { this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageOptionsProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageOptionsProperties.java new file mode 100644 index 00000000000..29d5fbf1247 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageOptionsProperties.java @@ -0,0 +1,134 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.openai; + +import org.springframework.ai.openai.OpenAiImageOptions; +import org.springframework.ai.openai.OpenAiImageOptionsBuilder; + +/** + * This class is an object for binding Spring Properties. As it implements the + * {@link OpenAiImageOptions} interface, it can be used as an immutable object through the + * {@link OpenAiImageOptionsBuilder}. + * + * @since 0.8.1 + * @author youngmon + */ +public class OpenAiImageOptionsProperties implements OpenAiImageOptions { + + private Integer n; + + private String model; + + private Integer width; + + private Integer height; + + private String quality; + + private String responseFormat; + + private String size; + + private String style; + + private String user; + + @Override + public Integer getN() { + return this.n; + } + + @Override + public String getModel() { + return this.model; + } + + @Override + public Integer getHeight() { + return this.height; + } + + @Override + public Integer getWidth() { + return this.width; + } + + @Override + public String getResponseFormat() { + return responseFormat; + } + + @Override + public String getSize() { + + if (this.size != null) { + return this.size; + } + return (this.width != null && this.height != null) ? this.width + "x" + this.height : null; + } + + @Override + public String getUser() { + return this.user; + } + + @Override + public String getStyle() { + return this.style; + } + + @Override + public String getQuality() { + return this.quality; + } + + public void setN(Integer n) { + this.n = n; + } + + public void setWidth(Integer width) { + this.width = width; + } + + public void setHeight(Integer height) { + this.height = height; + } + + public void setResponseFormat(String responseFormat) { + this.responseFormat = responseFormat; + } + + public void setQuality(String quality) { + this.quality = quality; + } + + public void setModel(String model) { + this.model = model; + } + + public void setStyle(String style) { + this.style = style; + } + + public void setUser(String user) { + this.user = user; + } + + public void setSize(String size) { + this.size = size; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java index 99fadafaf4d..baf256adfe6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiImageProperties.java @@ -39,13 +39,13 @@ public class OpenAiImageProperties extends OpenAiParentProperties { * Options for OpenAI Image API. */ @NestedConfigurationProperty - private OpenAiImageOptions options = OpenAiImageOptions.builder().build(); + private OpenAiImageOptionsProperties options = new OpenAiImageOptionsProperties(); public OpenAiImageOptions getOptions() { return options; } - public void setOptions(OpenAiImageOptions options) { + public void setOptions(OpenAiImageOptionsProperties options) { this.options = options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java index 15796f921b2..854f287932a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java @@ -32,7 +32,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.OpenAiChatOptionsBuilder; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -71,7 +71,7 @@ void functionCallTest() { ChatResponse response = chatClient .call(new Prompt(List.of(new UserMessage("What's the status of my transaction with id T1001?")), - OpenAiChatOptions.builder() + OpenAiChatOptionsBuilder.builder() .withFunction("retrievePaymentStatus") .withFunction("retrievePaymentDate") .build())); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java index 67770f6af82..298dd762faf 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.OpenAiChatOptionsBuilder; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; @@ -33,7 +34,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -58,7 +58,7 @@ void functionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - var promptOptions = OpenAiChatOptions.builder() + var promptOptions = OpenAiChatOptionsBuilder.builder() .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") @@ -83,7 +83,7 @@ void streamingFunctionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - var promptOptions = OpenAiChatOptions.builder() + var promptOptions = OpenAiChatOptionsBuilder.builder() .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) .withName("CurrentWeatherService") .withDescription("Get the weather in location") diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index d5437a62fdf..81e60e87511 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.OpenAiChatOptionsBuilder; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; @@ -32,10 +33,9 @@ import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; -import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.PortableFunctionCallingOptions; +import org.springframework.ai.model.function.PortableFunctionCallingOptionsBuilder; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -66,7 +66,7 @@ void functionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + OpenAiChatOptionsBuilder.builder().withFunction("weatherFunction").build())); logger.info("Response: {}", response); @@ -74,7 +74,7 @@ void functionCallTest() { // Test weatherFunctionTwo response = chatClient.call(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + OpenAiChatOptionsBuilder.builder().withFunction("weatherFunctionTwo").build())); logger.info("Response: {}", response); @@ -92,7 +92,7 @@ void functionCallWithPortableFunctionCallingOptions() { // Test weatherFunction UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + PortableFunctionCallingOptions functionOptions = PortableFunctionCallingOptionsBuilder.builder() .withFunction("weatherFunction") .build(); @@ -112,7 +112,7 @@ void streamFunctionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); Flux response = chatClient.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunction").build())); + OpenAiChatOptionsBuilder.builder().withFunction("weatherFunction").build())); String content = response.collectList() .block() @@ -130,7 +130,7 @@ void streamFunctionCallTest() { // Test weatherFunctionTwo response = chatClient.stream(new Prompt(List.of(userMessage), - OpenAiChatOptions.builder().withFunction("weatherFunctionTwo").build())); + OpenAiChatOptionsBuilder.builder().withFunction("weatherFunctionTwo").build())); content = response.collectList() .block() diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java index 44e2c3af740..06782e1a0e9 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.openai.OpenAiChatOptionsBuilder; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; @@ -34,7 +35,6 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -62,8 +62,8 @@ void functionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - ChatResponse response = chatClient.call( - new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("WeatherInfo").build())); + ChatResponse response = chatClient.call(new Prompt(List.of(userMessage), + OpenAiChatOptionsBuilder.builder().withFunction("WeatherInfo").build())); logger.info("Response: {}", response); @@ -80,8 +80,8 @@ void streamFunctionCallTest() { UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - Flux response = chatClient.stream( - new Prompt(List.of(userMessage), OpenAiChatOptions.builder().withFunction("WeatherInfo").build())); + Flux response = chatClient.stream(new Prompt(List.of(userMessage), + OpenAiChatOptionsBuilder.builder().withFunction("WeatherInfo").build())); String content = response.collectList() .block()