15
15
*/
16
16
package org .springframework .ai .azure .openai ;
17
17
18
- import java .util .Collections ;
19
- import java .util .HashSet ;
20
- import java .util .List ;
21
- import java .util .Set ;
22
-
23
18
import com .azure .ai .openai .OpenAIClient ;
24
19
import com .azure .ai .openai .models .ChatChoice ;
25
20
import com .azure .ai .openai .models .ChatCompletions ;
33
28
import com .azure .ai .openai .models .ChatRequestSystemMessage ;
34
29
import com .azure .ai .openai .models .ChatRequestToolMessage ;
35
30
import com .azure .ai .openai .models .ChatRequestUserMessage ;
36
- import com .azure .ai .openai .models .ChatResponseMessage ;
37
31
import com .azure .ai .openai .models .CompletionsFinishReason ;
38
32
import com .azure .ai .openai .models .ContentFilterResultsForPrompt ;
33
+ import com .azure .ai .openai .models .FunctionCall ;
39
34
import com .azure .ai .openai .models .FunctionDefinition ;
40
35
import com .azure .core .util .BinaryData ;
41
36
import com .azure .core .util .IterableStream ;
42
37
import org .slf4j .Logger ;
43
38
import org .slf4j .LoggerFactory ;
44
- import reactor .core .publisher .Flux ;
45
39
46
40
import org .springframework .ai .azure .openai .metadata .AzureOpenAiChatResponseMetadata ;
47
41
import org .springframework .ai .chat .ChatClient ;
59
53
import org .springframework .ai .model .function .FunctionCallbackContext ;
60
54
import org .springframework .util .Assert ;
61
55
import org .springframework .util .CollectionUtils ;
56
+ import reactor .core .publisher .Flux ;
57
+
58
+ import java .util .Collections ;
59
+ import java .util .HashSet ;
60
+ import java .util .List ;
61
+ import java .util .Optional ;
62
+ import java .util .Set ;
63
+ import java .util .concurrent .atomic .AtomicBoolean ;
62
64
63
65
/**
64
66
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
68
70
* @author Ueibin Kim
69
71
* @author John Blum
70
72
* @author Christian Tzolov
73
+ * @author Grogdunn
71
74
* @see ChatClient
72
75
* @see com.azure.ai.openai.OpenAIClient
73
76
*/
@@ -158,17 +161,42 @@ public Flux<ChatResponse> stream(Prompt prompt) {
158
161
IterableStream <ChatCompletions > chatCompletionsStream = this .openAIClient
159
162
.getChatCompletionsStream (options .getModel (), options );
160
163
161
- return Flux .fromStream (chatCompletionsStream .stream ()
164
+ Flux <ChatCompletions > chatCompletionsFlux = Flux .fromIterable (chatCompletionsStream );
165
+
166
+ final var isFunctionCall = new AtomicBoolean (false );
167
+ final var accessibleChatCompletionsFlux = chatCompletionsFlux
162
168
// Note: the first chat completions can be ignored when using Azure OpenAI
163
169
// service which is a known service bug.
164
170
.skip (1 )
165
- .map (ChatCompletions ::getChoices )
166
- .flatMap (List ::stream )
171
+ .map (chatCompletions -> {
172
+ final var toolCalls = chatCompletions .getChoices ().get (0 ).getDelta ().getToolCalls ();
173
+ isFunctionCall .set (toolCalls != null && !toolCalls .isEmpty ());
174
+ return chatCompletions ;
175
+ })
176
+ .windowUntil (chatCompletions -> {
177
+ if (isFunctionCall .get () && chatCompletions .getChoices ()
178
+ .get (0 )
179
+ .getFinishReason () == CompletionsFinishReason .TOOL_CALLS ) {
180
+ isFunctionCall .set (false );
181
+ return true ;
182
+ }
183
+ return false ;
184
+ }, false )
185
+ .concatMapIterable (window -> {
186
+ final var reduce = window .reduce (MergeUtils .emptyChatCompletions (), MergeUtils ::mergeChatCompletions );
187
+ return List .of (reduce );
188
+ })
189
+ .flatMap (mono -> mono );
190
+ return accessibleChatCompletionsFlux
191
+ .switchMap (accessibleChatCompletions -> handleFunctionCallOrReturnStream (options ,
192
+ Flux .just (accessibleChatCompletions )))
193
+ .flatMapIterable (ChatCompletions ::getChoices )
167
194
.map (choice -> {
168
- var content = (choice .getDelta () != null ) ? choice .getDelta ().getContent () : null ;
195
+ var content = Optional . ofNullable (choice .getMessage ()). orElse ( choice .getDelta ()) .getContent ();
169
196
var generation = new Generation (content ).withGenerationMetadata (generateChoiceMetadata (choice ));
170
197
return new ChatResponse (List .of (generation ));
171
- }));
198
+ });
199
+
172
200
}
173
201
174
202
/**
@@ -522,9 +550,17 @@ protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions requ
522
550
523
551
@ Override
524
552
protected ChatRequestMessage doGetToolResponseMessage (ChatCompletions response ) {
525
- ChatResponseMessage responseMessage = response .getChoices ().get (0 ).getMessage ();
553
+ final var accessibleChatChoice = response .getChoices ().get (0 );
554
+ var responseMessage = Optional .ofNullable (accessibleChatChoice .getMessage ())
555
+ .orElse (accessibleChatChoice .getDelta ());
526
556
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage ("" );
527
- assistantMessage .setToolCalls (responseMessage .getToolCalls ());
557
+ final var toolCalls = responseMessage .getToolCalls ();
558
+ assistantMessage .setToolCalls (toolCalls .stream ().map (tc -> {
559
+ final var tc1 = (ChatCompletionsFunctionToolCall ) tc ;
560
+ var toDowncast = new ChatCompletionsFunctionToolCall (tc .getId (),
561
+ new FunctionCall (tc1 .getFunction ().getName (), tc1 .getFunction ().getArguments ()));
562
+ return ((ChatCompletionsToolCall ) toDowncast );
563
+ }).toList ());
528
564
return assistantMessage ;
529
565
}
530
566
@@ -533,6 +569,11 @@ protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) {
533
569
return this .openAIClient .getChatCompletions (request .getModel (), request );
534
570
}
535
571
572
+ @ Override
573
+ protected Flux <ChatCompletions > doChatCompletionStream (ChatCompletionsOptions request ) {
574
+ return Flux .fromIterable (this .openAIClient .getChatCompletionsStream (request .getModel (), request ));
575
+ }
576
+
536
577
@ Override
537
578
protected boolean isToolFunctionCall (ChatCompletions chatCompletions ) {
538
579
@@ -549,4 +590,4 @@ protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
549
590
return choice .getFinishReason () == CompletionsFinishReason .TOOL_CALLS ;
550
591
}
551
592
552
- }
593
+ }
0 commit comments