diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 5d91bbd027b..4abb8e1e76d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -25,14 +25,8 @@ public interface ChatOptions extends ModelOptions { Float getTemperature(); - void setTemperature(Float temperature); - Float getTopP(); - void setTopP(Float topP); - Integer getTopK(); - void setTopK(Integer topK); - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java index f702f635097..ccae863dc9e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java @@ -18,47 +18,11 @@ public class ChatOptionsBuilder { - private class ChatOptionsImpl implements ChatOptions { - - private Float temperature; - - private Float topP; - - private Integer topK; - - @Override - public Float getTemperature() { - return temperature; - } + private Float temperature; - @Override - public void setTemperature(Float temperature) { - this.temperature = temperature; - } + private Float topP; - @Override - public Float getTopP() { - return topP; - } - - @Override - public void setTopP(Float topP) { - this.topP = topP; - } - - @Override - public Integer getTopK() { - return topK; - } - - @Override - public void setTopK(Integer topK) { - this.topK = topK; - } - - } - - private final ChatOptionsImpl options = new ChatOptionsImpl(); + private Integer topK; private ChatOptionsBuilder() { } @@ -67,23 +31,65 @@ public static ChatOptionsBuilder builder() { return new ChatOptionsBuilder(); } + /** + * Constructs a new immutable object based on the provided ChatOptions object. + * @param options The original ChatOptions object to base the new object on. + * @return ChatOptionsBuilder to construct a new ChatOptions object. + */ + public static ChatOptionsBuilder builder(ChatOptions options) { + return builder().withTopK(options.getTopK()) + .withTopP(options.getTopP()) + .withTemperature(options.getTemperature()); + } + public ChatOptionsBuilder withTemperature(Float temperature) { - options.setTemperature(temperature); + this.temperature = temperature; return this; } public ChatOptionsBuilder withTopP(Float topP) { - options.setTopP(topP); + this.topP = topP; return this; } public ChatOptionsBuilder withTopK(Integer topK) { - options.setTopK(topK); + this.topK = topK; return this; } public ChatOptions build() { - return options; + return new ChatOptionsImpl(this.temperature, this.topP, this.topK); + } + + private class ChatOptionsImpl implements ChatOptions { + + private final Float temperature; + + private final Float topP; + + private final Integer topK; + + ChatOptionsImpl(Float temperature, Float topP, Integer topK) { + this.temperature = temperature; + this.topP = topP; + this.topK = topK; + } + + @Override + public Float getTemperature() { + return temperature; + } + + @Override + public Float getTopP() { + return topP; + } + + @Override + public Integer getTopK() { + return topK; + } + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index c66a4f5b182..13e3017c6ac 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -18,11 +18,12 @@ import java.util.List; import java.util.Set; +import org.springframework.ai.chat.prompt.ChatOptions; /** * @author Christian Tzolov */ -public interface FunctionCallingOptions { +public interface FunctionCallingOptions extends ChatOptions { /** * Function Callbacks to be registered with the ChatClient. For Prompt Options the @@ -34,33 +35,10 @@ public interface FunctionCallingOptions { */ List getFunctionCallbacks(); - /** - * Set the Function Callbacks to be registered with the ChatClient. - * @param functionCallbacks the Function Callbacks to be registered with the - * ChatClient. - */ - void setFunctionCallbacks(List functionCallbacks); - /** * @return List of function names from the ChatClient registry to be used in the next * chat completion requests. */ Set getFunctions(); - /** - * Set the list of function names from the ChatClient registry to be used in the next - * chat completion requests. - * @param functions the list of function names from the ChatClient registry to be used - * in the next chat completion requests. - */ - void setFunctions(Set functions); - - /** - * @return Returns FunctionCallingOptionsBuilder to create a new instance of - * FunctionCallingOptions. - */ - public static FunctionCallingOptionsBuilder builder() { - return new FunctionCallingOptionsBuilder(); - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java index 948fba58f9b..bfcc2c1ece7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptionsBuilder.java @@ -22,6 +22,7 @@ import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; import org.springframework.util.Assert; /** @@ -34,117 +35,155 @@ */ public class FunctionCallingOptionsBuilder { - private final PortableFunctionCallingOptions options; + private List functionCallbacks = new ArrayList<>(); - public FunctionCallingOptionsBuilder() { - this.options = new PortableFunctionCallingOptions(); + private Set functions = new HashSet<>(); + + private ChatOptionsBuilder chatOptionsBuilder = ChatOptionsBuilder.builder(); + + private FunctionCallingOptionsBuilder() { + } + + /** + * Creates a new {@link FunctionCallingOptionsBuilder} instance. + * @return A new instance of FunctionCallingOptionsBuilder. + */ + public static FunctionCallingOptionsBuilder builder() { + return new FunctionCallingOptionsBuilder(); + } + + /** + * Initializes a new {@link FunctionCallingOptionsBuilder} with settings from an + * existing {@link ChatOptions} object. This allows for creating a new + * FunctionCallingOptions object based on the settings of an existing ChatOptions + * instance. + * @param options The ChatOptions object whose settings are to be used. + * @return A FunctionCallingOptionsBuilder instance initialized with the provided + * ChatOptions settings. + */ + public static FunctionCallingOptionsBuilder builder(ChatOptions options) { + return builder().withTopK(options.getTopK()) + .withTopP(options.getTopP()) + .withTemperature(options.getTemperature()); + } + + /** + * Initializes a new {@link FunctionCallingOptionsBuilder} with settings from an + * existing {@link FunctionCallingOptions} object. This method is useful for + * transferring settings between different instances of FunctionCallingOptions, + * including function callbacks and functions. + * @param options The PortableFunctionCallingOptions object whose settings are to be + * used. + * @return A FunctionCallingOptionsBuilder instance initialized with the provided + * PortableFunctionCallingOptions settings. + */ + public static FunctionCallingOptionsBuilder builder(FunctionCallingOptions options) { + return builder().withTopK(options.getTopK()) + .withTopP(options.getTopP()) + .withTemperature(options.getTemperature()) + .withFunctions(options.getFunctions()) + .withFunctionCallbacks(options.getFunctionCallbacks()); } public FunctionCallingOptionsBuilder withFunctionCallbacks(List functionCallbacks) { - this.options.setFunctionCallbacks(functionCallbacks); + Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + this.functionCallbacks.addAll(functionCallbacks); return this; } public FunctionCallingOptionsBuilder withFunctionCallback(FunctionCallback functionCallback) { Assert.notNull(functionCallback, "FunctionCallback must not be null"); - this.options.getFunctionCallbacks().add(functionCallback); + this.functionCallbacks.add(functionCallback); return this; } public FunctionCallingOptionsBuilder withFunctions(Set functions) { - this.options.setFunctions(functions); + Assert.notNull(functions, "Functions must not be null"); + this.functions.addAll(functions); return this; } public FunctionCallingOptionsBuilder withFunction(String function) { Assert.notNull(function, "Function must not be null"); - this.options.getFunctions().add(function); + this.functions.add(function); return this; } public FunctionCallingOptionsBuilder withTemperature(Float temperature) { - this.options.setTemperature(temperature); + this.chatOptionsBuilder.withTemperature(temperature); return this; } public FunctionCallingOptionsBuilder withTopP(Float topP) { - this.options.setTopP(topP); + this.chatOptionsBuilder.withTopP(topP); return this; } public FunctionCallingOptionsBuilder withTopK(Integer topK) { - this.options.setTopK(topK); + this.chatOptionsBuilder.withTopK(topK); return this; } public PortableFunctionCallingOptions build() { - return this.options; + return new PortableFunctionCallingOptions(this.functions, this.functionCallbacks, + this.chatOptionsBuilder.build()); } - public static class PortableFunctionCallingOptions implements FunctionCallingOptions, ChatOptions { + public class PortableFunctionCallingOptions implements FunctionCallingOptions { - private List functionCallbacks = new ArrayList<>(); + private final List functionCallbacks; - private Set functions = new HashSet<>(); + private final Set functions; - private Float temperature; + private final ChatOptions options; - private Float topP; - - private Integer topK; - - @Override - public List getFunctionCallbacks() { - return this.functionCallbacks; - } - - @Override - public void setFunctionCallbacks(List functionCallbacks) { - Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null"); + PortableFunctionCallingOptions(final Set functions, final List functionCallbacks, + ChatOptions options) { + this.functions = functions; this.functionCallbacks = functionCallbacks; + this.options = options; } + /** + * Retrieves a list of function callbacks. This method returns a new list + * containing all currently set function callbacks. The returned list is a copy, + * ensuring that modifications to the returned list do not affect the original + * list of function callbacks. This ensures the immutability of the collection + * exposed to the clients. + * @return An immutable list of {@link FunctionCallback} instances. + */ @Override - public Set getFunctions() { - return this.functions; + public List getFunctionCallbacks() { + return new ArrayList<>(this.functionCallbacks); } + /** + * Retrieves a set of functions. This method returns a new set containing all + * currently set functions. The returned set is a copy, ensuring that + * modifications to the returned set do not affect the original set of functions. + * This ensures the immutability of the collection exposed to the clients. + * @return An immutable set of String representing the functions. + */ @Override - public void setFunctions(Set functions) { - Assert.notNull(functions, "Functions must not be null"); - this.functions = functions; + public Set getFunctions() { + return new HashSet<>(this.functions); } @Override public Float getTemperature() { - return this.temperature; - } - - @Override - public void setTemperature(Float temperature) { - this.temperature = temperature; + return this.options.getTemperature(); } @Override public Float getTopP() { - return this.topP; - } - - @Override - public void setTopP(Float topP) { - this.topP = topP; + return this.options.getTopP(); } @Override public Integer getTopK() { - return this.topK; - } - - @Override - public void setTopK(Integer topK) { - this.topK = topK; + return this.options.getTopK(); } } -} +} \ No newline at end of file diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java new file mode 100644 index 00000000000..46ab5cdd79b --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/ChatBuilderTests.java @@ -0,0 +1,134 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder; + +/** + * Unit Tests for {@link Prompt}. + * + * @author youngmon + * @since 0.8.1 + */ +public class ChatBuilderTests { + + @Test + void createNewChatOptionsTest() { + Float temperature = 1.1f; + Float topP = 2.2f; + Integer topK = 111; + + ChatOptions options = ChatOptionsBuilder.builder() + .withTemperature(temperature) + .withTopK(topK) + .withTopP(topP) + .build(); + + assertThat(options.getTemperature()).isEqualTo(temperature); + assertThat(options.getTopP()).isEqualTo(topP); + assertThat(options.getTopK()).isEqualTo(topK); + } + + @Test + void duplicateChatOptionsTest() { + Float initTemperature = 1.1f; + Float initTopP = 2.2f; + Integer initTopK = 111; + + ChatOptions options = ChatOptionsBuilder.builder() + .withTemperature(initTemperature) + .withTopP(initTopP) + .withTopK(initTopK) + .build(); + + Integer newTopK = 222; + + ChatOptions newOptions = ChatOptionsBuilder.builder(options) + // setTopK + .withTopK(newTopK) + .build(); + + assertThat(newOptions.getTemperature()).isEqualTo(initTemperature); + assertThat(newOptions.getTopP()).isEqualTo(initTopP); + assertThat(newOptions.getTopK()).isEqualTo(newTopK); + } + + @Test + void createFunctionCallingOptionTest() { + Float temperature = 1.1f; + Float topP = 2.2f; + Integer topK = 111; + List functionCallbacks = new ArrayList<>(); + Set functions = new HashSet<>(); + + String func = "func"; + FunctionCallback cb = FunctionCallbackWrapper.builder(i -> i) + .withName("cb") + .withDescription("cb") + .build(); + + functions.add(func); + functionCallbacks.add(cb); + + FunctionCallingOptions options = FunctionCallingOptionsBuilder.builder() + .withFunctionCallbacks(functionCallbacks) + .withFunctions(functions) + .withTopK(topK) + .withTopP(topP) + .withTemperature(temperature) + .build(); + + // Callback Functions + assertThat(options.getFunctionCallbacks()).isNotNull(); + assertThat(options.getFunctionCallbacks().size()).isEqualTo(1); + assertThat(options.getFunctionCallbacks().contains(cb)); + + // Functions + assertThat(options.getFunctions()).isNotNull(); + assertThat(options.getFunctions().size()).isEqualTo(1); + assertThat(options.getFunctions().contains(func)); + + // Immutable + options.getFunctionCallbacks().add(cb); + assertThat(options.getFunctionCallbacks().size()).isEqualTo(1); + options.getFunctions().add(func + func); + assertThat(options.getFunctions().size()).isEqualTo(1); + + FunctionCallingOptions newOptions = FunctionCallingOptionsBuilder.builder(options) + .withFunction(func + func) + .withFunctionCallback(cb) + .build(); + + assertThat(newOptions.getFunctions().size()).isEqualTo(2); + assertThat(newOptions.getFunctionCallbacks().size()).isEqualTo(2); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index c005b075044..943a7c2f52c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -28,7 +28,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.function.FunctionCallingOptions; +import org.springframework.ai.model.function.FunctionCallingOptionsBuilder; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiChatOptions; @@ -87,7 +87,7 @@ void functionCallWithPortableFunctionCallingOptions() { // Test weatherFunction UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); - PortableFunctionCallingOptions functionOptions = FunctionCallingOptions.builder() + PortableFunctionCallingOptions functionOptions = FunctionCallingOptionsBuilder.builder() .withFunction("weatherFunction") .build();