1
+ /*
2
+ * Copyright 2023 - 2024 the original author or authors.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * https://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ package org .springframework .ai .mistralai .api ;
17
+
18
+ import java .util .ArrayList ;
19
+ import java .util .List ;
20
+ import java .util .Optional ;
21
+ import java .util .UUID ;
22
+
23
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionChunk ;
24
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionChunk .ChunkChoice ;
25
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionFinishReason ;
26
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionMessage ;
27
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionMessage .ChatCompletionFunction ;
28
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionMessage .Role ;
29
+ import org .springframework .ai .mistralai .api .MistralAiApi .ChatCompletionMessage .ToolCall ;
30
+ import org .springframework .util .CollectionUtils ;
31
+
32
+ /**
33
+ * Helper class to support Streaming function calling.
34
+ *
35
+ * It can merge the streamed ChatCompletionChunk in case of function calling message.
36
+ *
37
+ * @author Christian Tzolov
38
+ * @since 0.8.1
39
+ */
40
+ public class MIstralAiStreamFunctionCallingHelper {
41
+
42
+ /**
43
+ * Merge the previous and current ChatCompletionChunk into a single one.
44
+ * @param previous the previous ChatCompletionChunk
45
+ * @param current the current ChatCompletionChunk
46
+ * @return the merged ChatCompletionChunk
47
+ */
48
+ public ChatCompletionChunk merge (ChatCompletionChunk previous , ChatCompletionChunk current ) {
49
+
50
+ if (previous == null ) {
51
+ return current ;
52
+ }
53
+
54
+ String id = (current .id () != null ? current .id () : previous .id ());
55
+ Long created = (current .created () != null ? current .created () : previous .created ());
56
+ String model = (current .model () != null ? current .model () : previous .model ());
57
+ String object = (current .object () != null ? current .object () : previous .object ());
58
+
59
+ ChunkChoice previousChoice0 = (CollectionUtils .isEmpty (previous .choices ()) ? null : previous .choices ().get (0 ));
60
+ ChunkChoice currentChoice0 = (CollectionUtils .isEmpty (current .choices ()) ? null : current .choices ().get (0 ));
61
+
62
+ ChunkChoice choice = merge (previousChoice0 , currentChoice0 );
63
+
64
+ return new ChatCompletionChunk (id , object , created , model , List .of (choice ));
65
+ }
66
+
67
+ private ChunkChoice merge (ChunkChoice previous , ChunkChoice current ) {
68
+ if (previous == null ) {
69
+ if (current .delta () != null && current .delta ().toolCalls () != null ) {
70
+ Optional <String > id = current .delta ()
71
+ .toolCalls ()
72
+ .stream ()
73
+ .filter (tool -> tool .id () != null )
74
+ .map (tool -> tool .id ())
75
+ .findFirst ();
76
+ if (!id .isPresent ()) {
77
+ var newId = UUID .randomUUID ().toString ();
78
+
79
+ var toolCallsWithID = current .delta ()
80
+ .toolCalls ()
81
+ .stream ()
82
+ .map (toolCall -> new ToolCall (newId , "function" , toolCall .function ()))
83
+ .toList ();
84
+
85
+ var role = current .delta ().role () != null ? current .delta ().role () : Role .ASSISTANT ;
86
+ current = new ChunkChoice (current .index (), new ChatCompletionMessage (current .delta ().content (),
87
+ role , current .delta ().name (), toolCallsWithID ), current .finishReason ());
88
+ }
89
+ }
90
+ return current ;
91
+ }
92
+
93
+ ChatCompletionFinishReason finishReason = (current .finishReason () != null ? current .finishReason ()
94
+ : previous .finishReason ());
95
+ Integer index = (current .index () != null ? current .index () : previous .index ());
96
+
97
+ ChatCompletionMessage message = merge (previous .delta (), current .delta ());
98
+
99
+ return new ChunkChoice (index , message , finishReason );
100
+ }
101
+
102
+ private ChatCompletionMessage merge (ChatCompletionMessage previous , ChatCompletionMessage current ) {
103
+ String content = (current .content () != null ? current .content ()
104
+ : "" + ((previous .content () != null ) ? previous .content () : "" ));
105
+ Role role = (current .role () != null ? current .role () : previous .role ());
106
+ role = (role != null ? role : Role .ASSISTANT ); // default to ASSISTANT (if null
107
+ String name = (current .name () != null ? current .name () : previous .name ());
108
+
109
+ List <ToolCall > toolCalls = new ArrayList <>();
110
+ ToolCall lastPreviousTooCall = null ;
111
+ if (previous .toolCalls () != null ) {
112
+ lastPreviousTooCall = previous .toolCalls ().get (previous .toolCalls ().size () - 1 );
113
+ if (previous .toolCalls ().size () > 1 ) {
114
+ toolCalls .addAll (previous .toolCalls ().subList (0 , previous .toolCalls ().size () - 1 ));
115
+ }
116
+ }
117
+ if (current .toolCalls () != null ) {
118
+ if (current .toolCalls ().size () > 1 ) {
119
+ throw new IllegalStateException ("Currently only one tool call is supported per message!" );
120
+ }
121
+ var currentToolCall = current .toolCalls ().iterator ().next ();
122
+ if (currentToolCall .id () != null ) {
123
+ if (lastPreviousTooCall != null ) {
124
+ toolCalls .add (lastPreviousTooCall );
125
+ }
126
+ toolCalls .add (currentToolCall );
127
+ }
128
+ else {
129
+ toolCalls .add (merge (lastPreviousTooCall , currentToolCall ));
130
+ }
131
+ }
132
+ else {
133
+ if (lastPreviousTooCall != null ) {
134
+ toolCalls .add (lastPreviousTooCall );
135
+ }
136
+ }
137
+ return new ChatCompletionMessage (content , role , name , toolCalls );
138
+ }
139
+
140
+ private ToolCall merge (ToolCall previous , ToolCall current ) {
141
+ if (previous == null ) {
142
+ return current ;
143
+ }
144
+ String id = (current .id () != null ? current .id () : previous .id ());
145
+ String type = (current .type () != null ? current .type () : previous .type ());
146
+ ChatCompletionFunction function = merge (previous .function (), current .function ());
147
+ return new ToolCall (id , type , function );
148
+ }
149
+
150
+ private ChatCompletionFunction merge (ChatCompletionFunction previous , ChatCompletionFunction current ) {
151
+ if (previous == null ) {
152
+ return current ;
153
+ }
154
+ String name = (current .name () != null ? current .name () : previous .name ());
155
+ StringBuilder arguments = new StringBuilder ();
156
+ if (previous .arguments () != null ) {
157
+ arguments .append (previous .arguments ());
158
+ }
159
+ if (current .arguments () != null ) {
160
+ arguments .append (current .arguments ());
161
+ }
162
+ return new ChatCompletionFunction (name , arguments .toString ());
163
+ }
164
+
165
+ /**
166
+ * @param chatCompletion the ChatCompletionChunk to check
167
+ * @return true if the ChatCompletionChunk is a streaming tool function call.
168
+ */
169
+ public boolean isStreamingToolFunctionCall (ChatCompletionChunk chatCompletion ) {
170
+
171
+ var choices = chatCompletion .choices ();
172
+ if (CollectionUtils .isEmpty (choices )) {
173
+ return false ;
174
+ }
175
+
176
+ var choice = choices .get (0 );
177
+ return !CollectionUtils .isEmpty (choice .delta ().toolCalls ());
178
+ }
179
+
180
+ /**
181
+ * @param chatCompletion the ChatCompletionChunk to check
182
+ * @return true if the ChatCompletionChunk is a streaming tool function call and it is
183
+ * the last one.
184
+ */
185
+ public boolean isStreamingToolFunctionCallFinish (ChatCompletionChunk chatCompletion ) {
186
+
187
+ var choices = chatCompletion .choices ();
188
+ if (CollectionUtils .isEmpty (choices )) {
189
+ return false ;
190
+ }
191
+
192
+ var choice = choices .get (0 );
193
+ return choice .finishReason () == ChatCompletionFinishReason .TOOL_CALL
194
+ || choice .finishReason () == ChatCompletionFinishReason .TOOL_CALLS ;
195
+ }
196
+
197
+ }
198
+ // ---
0 commit comments