Skip to content

Commit 7e4b90d

Browse files
committed
AzureOpenAI - Adopt ToolCallingManager API
- Use the new ToolCallingManager API for AzureOpenAI chat model - Add Builder to construct AzureOpenAI chat model instance - Deprecate existing constructors - Update documentation about the change Signed-off-by: Ilayaperumal Gopinathan <ilayaperumal.gopinathan@broadcom.com>
1 parent ee9eb05 commit 7e4b90d

File tree

11 files changed

+400
-199
lines changed

11 files changed

+400
-199
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 235 additions & 36 deletions
Large diffs are not rendered by default.

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 126 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.springframework.ai.azure.openai;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
21+
import java.util.HashMap;
2022
import java.util.HashSet;
2123
import java.util.List;
2224
import java.util.Map;
@@ -30,7 +32,9 @@
3032
import com.fasterxml.jackson.annotation.JsonProperty;
3133

3234
import org.springframework.ai.model.function.FunctionCallback;
33-
import org.springframework.ai.model.function.FunctionCallingOptions;
35+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
36+
import org.springframework.ai.tool.ToolCallback;
37+
import org.springframework.lang.Nullable;
3438
import org.springframework.util.Assert;
3539

3640
/**
@@ -44,7 +48,7 @@
4448
* @author Ilayaperumal Gopinathan
4549
*/
4650
@JsonInclude(Include.NON_NULL)
47-
public class AzureOpenAiChatOptions implements FunctionCallingOptions {
51+
public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
4852

4953
/**
5054
* The maximum number of tokens to generate.
@@ -138,33 +142,6 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
138142
@JsonProperty("response_format")
139143
private AzureOpenAiResponseFormat responseFormat;
140144

141-
/**
142-
* OpenAI Tool Function Callbacks to register with the ChatModel. For Prompt Options
143-
* the functionCallbacks are automatically enabled for the duration of the prompt
144-
* execution. For Default Options the functionCallbacks are registered but disabled by
145-
* default. Use the enableFunctions to set the functions from the registry to be used
146-
* by the ChatModel chat completion requests.
147-
*/
148-
@JsonIgnore
149-
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
150-
151-
/**
152-
* List of functions, identified by their names, to configure for function calling in
153-
* the chat completion requests. Functions with those names must exist in the
154-
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
155-
* are automatically enabled for the duration of the prompt execution.
156-
*
157-
* Note that function enabled with the default options are enabled for all chat
158-
* completion requests. This could impact the token count and the billing. If the
159-
* functions is set in a prompt options, then the enabled functions are only active
160-
* for the duration of this prompt execution.
161-
*/
162-
@JsonIgnore
163-
private Set<String> functions = new HashSet<>();
164-
165-
@JsonIgnore
166-
private Boolean proxyToolCalls;
167-
168145
/**
169146
* Seed value for deterministic sampling such that the same seed and parameters return
170147
* the same result.
@@ -199,7 +176,68 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions {
199176
private ChatCompletionStreamOptions streamOptions;
200177

201178
@JsonIgnore
202-
private Map<String, Object> toolContext;
179+
private Map<String, Object> toolContext = new HashMap<>();
180+
181+
/**
182+
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
183+
* completion requests.
184+
*/
185+
@JsonIgnore
186+
private List<FunctionCallback> toolCallbacks = new ArrayList<>();
187+
188+
/**
189+
* Collection of tool names to be resolved at runtime and used for tool calling in the
190+
* chat completion requests.
191+
*/
192+
@JsonIgnore
193+
private Set<String> toolNames = new HashSet<>();
194+
195+
/**
196+
* Whether to enable the tool execution lifecycle internally in ChatModel.
197+
*/
198+
@JsonIgnore
199+
private Boolean internalToolExecutionEnabled;
200+
201+
@Override
202+
@JsonIgnore
203+
public List<FunctionCallback> getToolCallbacks() {
204+
return this.toolCallbacks;
205+
}
206+
207+
@Override
208+
@JsonIgnore
209+
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
210+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
211+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
212+
this.toolCallbacks = toolCallbacks;
213+
}
214+
215+
@Override
216+
@JsonIgnore
217+
public Set<String> getToolNames() {
218+
return this.toolNames;
219+
}
220+
221+
@Override
222+
@JsonIgnore
223+
public void setToolNames(Set<String> toolNames) {
224+
Assert.notNull(toolNames, "toolNames cannot be null");
225+
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
226+
this.toolNames = toolNames;
227+
}
228+
229+
@Override
230+
@Nullable
231+
@JsonIgnore
232+
public Boolean isInternalToolExecutionEnabled() {
233+
return internalToolExecutionEnabled;
234+
}
235+
236+
@Override
237+
@JsonIgnore
238+
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
239+
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
240+
}
203241

204242
public static Builder builder() {
205243
return new Builder();
@@ -224,7 +262,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
224262
.topLogprobs(fromOptions.getTopLogProbs())
225263
.enhancements(fromOptions.getEnhancements())
226264
.toolContext(fromOptions.getToolContext())
265+
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
227266
.streamOptions(fromOptions.getStreamOptions())
267+
.toolCallbacks(fromOptions.getToolCallbacks())
268+
.toolNames(fromOptions.getToolNames())
228269
.build();
229270
}
230271

@@ -336,21 +377,28 @@ public void setTopP(Double topP) {
336377
}
337378

338379
@Override
380+
@Deprecated
381+
@JsonIgnore
339382
public List<FunctionCallback> getFunctionCallbacks() {
340-
return this.functionCallbacks;
383+
return this.getToolCallbacks();
341384
}
342385

386+
@Override
387+
@Deprecated
388+
@JsonIgnore
343389
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
344-
this.functionCallbacks = functionCallbacks;
390+
this.setToolCallbacks(functionCallbacks);
345391
}
346392

347393
@Override
394+
@Deprecated
395+
@JsonIgnore
348396
public Set<String> getFunctions() {
349-
return this.functions;
397+
return this.getToolNames();
350398
}
351399

352400
public void setFunctions(Set<String> functions) {
353-
this.functions = functions;
401+
this.setToolNames(functions);
354402
}
355403

356404
public AzureOpenAiResponseFormat getResponseFormat() {
@@ -400,12 +448,16 @@ public void setEnhancements(AzureChatEnhancementConfiguration enhancements) {
400448
}
401449

402450
@Override
451+
@Deprecated
452+
@JsonIgnore
403453
public Boolean getProxyToolCalls() {
404-
return this.proxyToolCalls;
454+
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
405455
}
406456

457+
@Deprecated
458+
@JsonIgnore
407459
public void setProxyToolCalls(Boolean proxyToolCalls) {
408-
this.proxyToolCalls = proxyToolCalls;
460+
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
409461
}
410462

411463
@Override
@@ -493,30 +545,31 @@ public Builder user(String user) {
493545
return this;
494546
}
495547

548+
@Deprecated
496549
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
497-
this.options.functionCallbacks = functionCallbacks;
498-
return this;
550+
return toolCallbacks(functionCallbacks);
499551
}
500552

553+
@Deprecated
501554
public Builder functions(Set<String> functionNames) {
502-
Assert.notNull(functionNames, "Function names must not be null");
503-
this.options.functions = functionNames;
504-
return this;
555+
return toolNames(functionNames);
505556
}
506557

558+
@Deprecated
507559
public Builder function(String functionName) {
508-
Assert.hasText(functionName, "Function name must not be empty");
509-
this.options.functions.add(functionName);
510-
return this;
560+
return toolNames(functionName);
511561
}
512562

513563
public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) {
514564
this.options.responseFormat = responseFormat;
515565
return this;
516566
}
517567

568+
@Deprecated
518569
public Builder proxyToolCalls(Boolean proxyToolCalls) {
519-
this.options.proxyToolCalls = proxyToolCalls;
570+
if (proxyToolCalls != null) {
571+
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
572+
}
520573
return this;
521574
}
522575

@@ -555,6 +608,34 @@ public Builder streamOptions(ChatCompletionStreamOptions streamOptions) {
555608
return this;
556609
}
557610

611+
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
612+
this.options.setToolCallbacks(toolCallbacks);
613+
return this;
614+
}
615+
616+
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
617+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
618+
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
619+
return this;
620+
}
621+
622+
public Builder toolNames(Set<String> toolNames) {
623+
Assert.notNull(toolNames, "toolNames cannot be null");
624+
this.options.setToolNames(toolNames);
625+
return this;
626+
}
627+
628+
public Builder toolNames(String... toolNames) {
629+
Assert.notNull(toolNames, "toolNames cannot be null");
630+
this.options.toolNames.addAll(Set.of(toolNames));
631+
return this;
632+
}
633+
634+
public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
635+
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
636+
return this;
637+
}
638+
558639
public AzureOpenAiChatOptions build() {
559640
return this.options;
560641
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ public void createRequestWithChatOptions() {
7171
.responseFormat(AzureOpenAiResponseFormat.TEXT)
7272
.build();
7373

74-
var client = new AzureOpenAiChatModel(mockClient, defaultOptions);
74+
var client = AzureOpenAiChatModel.builder()
75+
.openAIClientBuilder(mockClient)
76+
.defaultOptions(defaultOptions)
77+
.build();
7578

7679
var requestOptions = client.toAzureChatCompletionsOptions(new Prompt("Test message content"));
7780

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatClientIT.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,10 @@ public OpenAIClientBuilder openAIClient() {
161161

162162
@Bean
163163
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
164-
return new AzureOpenAiChatModel(openAIClientBuilder,
165-
AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build());
166-
164+
return AzureOpenAiChatModel.builder()
165+
.openAIClientBuilder(openAIClientBuilder)
166+
.defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build())
167+
.build();
167168
}
168169

169170
@Bean

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelIT.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,10 @@ public OpenAIClientBuilder openAIClientBuilder() {
269269

270270
@Bean
271271
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder) {
272-
return new AzureOpenAiChatModel(openAIClientBuilder,
273-
AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build());
274-
272+
return AzureOpenAiChatModel.builder()
273+
.openAIClientBuilder(openAIClientBuilder)
274+
.defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build())
275+
.build();
275276
}
276277

277278
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatModelObservationIT.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,11 @@ public OpenAIClientBuilder openAIClient() {
194194
@Bean
195195
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder,
196196
TestObservationRegistry observationRegistry) {
197-
return new AzureOpenAiChatModel(openAIClientBuilder,
198-
AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build(), null, List.of(),
199-
observationRegistry);
197+
return AzureOpenAiChatModel.builder()
198+
.openAIClientBuilder(openAIClientBuilder)
199+
.defaultOptions(AzureOpenAiChatOptions.builder().deploymentName("gpt-4o").maxTokens(1000).build())
200+
.observationRegistry(observationRegistry)
201+
.build();
200202
}
201203

202204
}

0 commit comments

Comments
 (0)