17
17
package org .springframework .ai .azure .openai ;
18
18
19
19
import java .util .Collections ;
20
+ import java .util .HashSet ;
20
21
import java .util .List ;
22
+ import java .util .Set ;
21
23
22
24
import com .azure .ai .openai .OpenAIClient ;
23
25
import com .azure .ai .openai .models .ChatChoice ;
24
26
import com .azure .ai .openai .models .ChatCompletions ;
27
+ import com .azure .ai .openai .models .ChatCompletionsFunctionToolCall ;
28
+ import com .azure .ai .openai .models .ChatCompletionsFunctionToolDefinition ;
25
29
import com .azure .ai .openai .models .ChatCompletionsOptions ;
30
+ import com .azure .ai .openai .models .ChatCompletionsToolCall ;
31
+ import com .azure .ai .openai .models .ChatCompletionsToolDefinition ;
26
32
import com .azure .ai .openai .models .ChatRequestAssistantMessage ;
27
33
import com .azure .ai .openai .models .ChatRequestMessage ;
28
34
import com .azure .ai .openai .models .ChatRequestSystemMessage ;
35
+ import com .azure .ai .openai .models .ChatRequestToolMessage ;
29
36
import com .azure .ai .openai .models .ChatRequestUserMessage ;
37
+ import com .azure .ai .openai .models .ChatResponseMessage ;
38
+ import com .azure .ai .openai .models .CompletionsFinishReason ;
30
39
import com .azure .ai .openai .models .ContentFilterResultsForPrompt ;
40
+ import com .azure .ai .openai .models .FunctionDefinition ;
41
+ import com .azure .core .util .BinaryData ;
31
42
import com .azure .core .util .IterableStream ;
32
43
import org .slf4j .Logger ;
33
44
import org .slf4j .LoggerFactory ;
34
45
import reactor .core .publisher .Flux ;
35
46
36
47
import org .springframework .ai .azure .openai .metadata .AzureOpenAiChatResponseMetadata ;
37
48
import org .springframework .ai .chat .ChatClient ;
38
- import org .springframework .ai .chat .prompt .ChatOptions ;
39
49
import org .springframework .ai .chat .ChatResponse ;
40
50
import org .springframework .ai .chat .Generation ;
41
51
import org .springframework .ai .chat .StreamingChatClient ;
42
52
import org .springframework .ai .chat .messages .Message ;
43
53
import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
44
54
import org .springframework .ai .chat .metadata .PromptMetadata ;
45
55
import org .springframework .ai .chat .metadata .PromptMetadata .PromptFilterMetadata ;
56
+ import org .springframework .ai .chat .prompt .ChatOptions ;
46
57
import org .springframework .ai .chat .prompt .Prompt ;
47
58
import org .springframework .ai .model .ModelOptionsUtils ;
59
+ import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
60
+ import org .springframework .ai .model .function .FunctionCallbackContext ;
48
61
import org .springframework .util .Assert ;
62
+ import org .springframework .util .CollectionUtils ;
49
63
50
64
/**
51
65
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
58
72
* @see ChatClient
59
73
* @see com.azure.ai.openai.OpenAIClient
60
74
*/
61
- public class AzureOpenAiChatClient implements ChatClient , StreamingChatClient {
75
+ public class AzureOpenAiChatClient
76
+ extends AbstractFunctionCallSupport <ChatRequestMessage , ChatCompletionsOptions , ChatCompletions >
77
+ implements ChatClient , StreamingChatClient {
62
78
63
79
private static final String DEFAULT_MODEL = "gpt-35-turbo" ;
64
80
@@ -82,6 +98,12 @@ public AzureOpenAiChatClient(OpenAIClient microsoftOpenAiClient) {
82
98
}
83
99
84
100
public AzureOpenAiChatClient (OpenAIClient microsoftOpenAiClient , AzureOpenAiChatOptions options ) {
101
+ this (microsoftOpenAiClient , options , null );
102
+ }
103
+
104
+ public AzureOpenAiChatClient (OpenAIClient microsoftOpenAiClient , AzureOpenAiChatOptions options ,
105
+ FunctionCallbackContext functionCallbackContext ) {
106
+ super (functionCallbackContext );
85
107
Assert .notNull (microsoftOpenAiClient , "com.azure.ai.openai.OpenAIClient must not be null" );
86
108
Assert .notNull (options , "AzureOpenAiChatOptions must not be null" );
87
109
this .openAIClient = microsoftOpenAiClient ;
@@ -100,7 +122,7 @@ public AzureOpenAiChatClient withDefaultOptions(AzureOpenAiChatOptions defaultOp
100
122
}
101
123
102
124
public AzureOpenAiChatOptions getDefaultOptions () {
103
- return defaultOptions ;
125
+ return this . defaultOptions ;
104
126
}
105
127
106
128
@ Override
@@ -111,7 +133,10 @@ public ChatResponse call(Prompt prompt) {
111
133
112
134
logger .trace ("Azure ChatCompletionsOptions: {}" , options );
113
135
114
- ChatCompletions chatCompletions = this .openAIClient .getChatCompletions (options .getModel (), options );
136
+ ChatCompletions chatCompletions = this .callWithFunctionSupport (options );
137
+
138
+ // ChatCompletions chatCompletions =
139
+ // this.openAIClient.getChatCompletions(options.getModel(), options);
115
140
116
141
logger .trace ("Azure ChatCompletions: {}" , chatCompletions );
117
142
@@ -154,6 +179,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
154
179
*/
155
180
ChatCompletionsOptions toAzureChatCompletionsOptions (Prompt prompt ) {
156
181
182
+ Set <String > functionsForThisRequest = new HashSet <>();
183
+
157
184
List <ChatRequestMessage > azureMessages = prompt .getInstructions ()
158
185
.stream ()
159
186
.map (this ::fromSpringAiMessage )
@@ -167,6 +194,10 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
167
194
// options = ModelOptionsUtils.merge(options, this.defaultOptions,
168
195
// ChatCompletionsOptions.class);
169
196
options = merge (options , this .defaultOptions );
197
+
198
+ Set <String > defaultEnabledFunctions = this .handleFunctionCallbackConfigurations (this .defaultOptions ,
199
+ !IS_RUNTIME_CALL );
200
+ functionsForThisRequest .addAll (defaultEnabledFunctions );
170
201
}
171
202
172
203
if (prompt .getOptions () != null ) {
@@ -178,16 +209,43 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
178
209
// options = ModelOptionsUtils.merge(runtimeOptions, options,
179
210
// ChatCompletionsOptions.class);
180
211
options = merge (updatedRuntimeOptions , options );
212
+
213
+ Set <String > promptEnabledFunctions = this .handleFunctionCallbackConfigurations (updatedRuntimeOptions ,
214
+ IS_RUNTIME_CALL );
215
+ functionsForThisRequest .addAll (promptEnabledFunctions );
216
+
181
217
}
182
218
else {
183
219
throw new IllegalArgumentException ("Prompt options are not of type ChatCompletionsOptions:"
184
220
+ prompt .getOptions ().getClass ().getSimpleName ());
185
221
}
186
222
}
187
223
224
+ // Add the enabled functions definitions to the request's tools parameter.
225
+
226
+ if (!CollectionUtils .isEmpty (functionsForThisRequest )) {
227
+ List <ChatCompletionsFunctionToolDefinition > tools = this .getFunctionTools (functionsForThisRequest );
228
+ List <ChatCompletionsToolDefinition > tools2 = tools .stream ()
229
+ .map (t -> ((ChatCompletionsToolDefinition ) t ))
230
+ .toList ();
231
+ options .setTools (tools2 );
232
+ }
233
+
188
234
return options ;
189
235
}
190
236
237
+ private List <ChatCompletionsFunctionToolDefinition > getFunctionTools (Set <String > functionNames ) {
238
+ return this .resolveFunctionCallbacks (functionNames ).stream ().map (functionCallback -> {
239
+
240
+ FunctionDefinition functionDefinition = new FunctionDefinition (functionCallback .getName ());
241
+ functionDefinition .setDescription (functionCallback .getDescription ());
242
+ BinaryData parameters = BinaryData
243
+ .fromObject (ModelOptionsUtils .jsonToMap (functionCallback .getInputTypeSchema ()));
244
+ functionDefinition .setParameters (parameters );
245
+ return new ChatCompletionsFunctionToolDefinition (functionDefinition );
246
+ }).toList ();
247
+ }
248
+
191
249
private ChatRequestMessage fromSpringAiMessage (Message message ) {
192
250
193
251
switch (message .getMessageType ()) {
@@ -281,6 +339,8 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, Cha
281
339
ChatCompletionsOptions mergedAzureOptions = new ChatCompletionsOptions (azureOptions .getMessages ());
282
340
mergedAzureOptions = merge (azureOptions , mergedAzureOptions );
283
341
342
+ mergedAzureOptions .setStream (azureOptions .isStream ());
343
+
284
344
if (springAiOptions .getMaxTokens () != null ) {
285
345
mergedAzureOptions .setMaxTokens (springAiOptions .getMaxTokens ());
286
346
}
@@ -324,6 +384,8 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions springAiOptions, Cha
324
384
return mergedAzureOptions ;
325
385
}
326
386
387
+ // https://github.com/Azure/azure-sdk-for-java/blob/azure-ai-openai_1.0.0-beta.6/sdk/openai/azure-ai-openai/src/samples/java/com/azure/ai/openai/usage/GetChatCompletionsToolCallSample.java
388
+
327
389
private ChatCompletionsOptions merge (ChatCompletionsOptions fromOptions , ChatCompletionsOptions toOptions ) {
328
390
329
391
if (fromOptions == null ) {
@@ -367,4 +429,68 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
367
429
return mergedOptions ;
368
430
}
369
431
432
+ @ Override
433
+ protected ChatCompletionsOptions doCreateToolResponseRequest (ChatCompletionsOptions previousRequest ,
434
+ ChatRequestMessage responseMessage , List <ChatRequestMessage > conversationHistory ) {
435
+
436
+ // Every tool-call item requires a separate function call and a response (TOOL)
437
+ // message.
438
+ for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage ) responseMessage ).getToolCalls ()) {
439
+
440
+ var functionName = ((ChatCompletionsFunctionToolCall ) toolCall ).getFunction ().getName ();
441
+ String functionArguments = ((ChatCompletionsFunctionToolCall ) toolCall ).getFunction ().getArguments ();
442
+
443
+ if (!this .functionCallbackRegister .containsKey (functionName )) {
444
+ throw new IllegalStateException ("No function callback found for function name: " + functionName );
445
+ }
446
+
447
+ String functionResponse = this .functionCallbackRegister .get (functionName ).call (functionArguments );
448
+
449
+ // Add the function response to the conversation.
450
+ conversationHistory .add (new ChatRequestToolMessage (functionResponse , toolCall .getId ()));
451
+ }
452
+
453
+ // Recursively call chatCompletionWithTools until the model doesn't call a
454
+ // functions anymore.
455
+ ChatCompletionsOptions newRequest = new ChatCompletionsOptions (conversationHistory );
456
+
457
+ newRequest = merge (previousRequest , newRequest );
458
+
459
+ return newRequest ;
460
+ }
461
+
462
+ @ Override
463
+ protected List <ChatRequestMessage > doGetUserMessages (ChatCompletionsOptions request ) {
464
+ return request .getMessages ();
465
+ }
466
+
467
+ @ Override
468
+ protected ChatRequestMessage doGetToolResponseMessage (ChatCompletions response ) {
469
+ ChatResponseMessage responseMessage = response .getChoices ().get (0 ).getMessage ();
470
+ ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage ("" );
471
+ assistantMessage .setToolCalls (responseMessage .getToolCalls ());
472
+ return assistantMessage ;
473
+ }
474
+
475
+ @ Override
476
+ protected ChatCompletions doChatCompletion (ChatCompletionsOptions request ) {
477
+ return this .openAIClient .getChatCompletions (request .getModel (), request );
478
+ }
479
+
480
+ @ Override
481
+ protected boolean isToolFunctionCall (ChatCompletions chatCompletions ) {
482
+
483
+ if (chatCompletions == null || CollectionUtils .isEmpty (chatCompletions .getChoices ())) {
484
+ return false ;
485
+ }
486
+
487
+ var choice = chatCompletions .getChoices ().get (0 );
488
+
489
+ if (choice == null || choice .getFinishReason () == null ) {
490
+ return false ;
491
+ }
492
+
493
+ return choice .getFinishReason () == CompletionsFinishReason .TOOL_CALLS ;
494
+ }
495
+
370
496
}
0 commit comments