.Net: Include streaming tool call information in model diagnostics #6305

Merged
100 changes: 59 additions & 41 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -662,6 +662,8 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
         string? streamedName = null;
         ChatRole? streamedRole = default;
         CompletionsFinishReason finishReason = default;
+        ChatCompletionsFunctionToolCall[]? toolCalls = null;
+        FunctionCallContent[]? functionCallContents = null;
 
         using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, chatExecutionSettings))
         {
@@ -717,10 +719,16 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
                     streamedContents?.Add(openAIStreamingChatMessageContent);
                     yield return openAIStreamingChatMessageContent;
                 }
+
+                // Translate all entries into ChatCompletionsFunctionToolCall instances.
+                toolCalls = OpenAIFunctionToolCall.ConvertToolCallUpdatesToChatCompletionsFunctionToolCalls(
+                    ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
+                // Translate all entries into FunctionCallContent instances for diagnostics purposes.
+                functionCallContents = ModelDiagnostics.IsSensitiveEventsEnabled() ? toolCalls.Select(this.GetFunctionCallContent).ToArray() : null;
             }
             finally
             {
-                activity?.EndStreaming(streamedContents);
+                activity?.EndStreaming(streamedContents, functionCallContents);
                 await responseEnumerator.DisposeAsync();
             }
         }
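
A side note on the ref dictionaries threaded through ConvertToolCallUpdatesToChatCompletionsFunctionToolCalls: streamed tool calls arrive as fragments (an id, a function name, and argument text split across many updates), keyed by tool-call index, and must be stitched back together before they are usable. Below is a minimal sketch of the accumulation side, assuming the dictionary shapes the parameter names suggest; the TrackToolCallUpdate helper and sample values are illustrative, not part of the connector:

using System;
using System.Collections.Generic;
using System.Text;

var toolCallIdsByIndex = new Dictionary<int, string>();
var functionNamesByIndex = new Dictionary<int, string>();
var functionArgumentBuildersByIndex = new Dictionary<int, StringBuilder>();

void TrackToolCallUpdate(int index, string? id, string? name, string? argumentsChunk)
{
    if (id is not null) { toolCallIdsByIndex[index] = id; }
    if (name is not null) { functionNamesByIndex[index] = name; }
    if (argumentsChunk is not null)
    {
        if (!functionArgumentBuildersByIndex.TryGetValue(index, out var builder))
        {
            functionArgumentBuildersByIndex[index] = builder = new StringBuilder();
        }
        builder.Append(argumentsChunk); // arguments stream in as raw JSON fragments
    }
}

// Example: updates for tool call #0 arriving across two stream chunks.
TrackToolCallUpdate(0, id: "call_abc", name: "WeatherPlugin-GetForecast", argumentsChunk: "{\"city\":");
TrackToolCallUpdate(0, id: null, name: null, argumentsChunk: "\"Paris\"}");
Console.WriteLine(functionArgumentBuildersByIndex[0]); // {"city":"Paris"}
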
@@ -738,10 +746,6 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
             // Get any response content that was streamed.
             string content = contentBuilder?.ToString() ?? string.Empty;
 
-            // Translate all entries into ChatCompletionsFunctionToolCall instances.
-            ChatCompletionsFunctionToolCall[] toolCalls = OpenAIFunctionToolCall.ConvertToolCallUpdatesToChatCompletionsFunctionToolCalls(
-                ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
-
             // Log the requests
             if (this.Logger.IsEnabled(LogLevel.Trace))
             {
@@ -755,7 +759,17 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
             // Add the original assistant message to the chatOptions; this is required for the service
             // to understand the tool call responses.
             chatOptions.Messages.Add(GetRequestMessage(streamedRole ?? default, content, streamedName, toolCalls));
-            chat.Add(new OpenAIChatMessageContent(streamedRole ?? default, content, this.DeploymentOrModelName, toolCalls, metadata) { AuthorName = streamedName });
+            // Add the result message to the caller's chat history
+            var newChatMessageContent = new OpenAIChatMessageContent(streamedRole ?? default, content, this.DeploymentOrModelName, toolCalls, metadata)
+            {
+                AuthorName = streamedName
+            };
+            // Add the tool call messages to the new chat message content for diagnostics purposes.
+            foreach (var functionCall in functionCallContents ?? [])
+            {
+                newChatMessageContent.Items.Add(functionCall);
+            }
+            chat.Add(newChatMessageContent);
 
             // Respond to each tooling request.
             for (int toolCallIndex = 0; toolCallIndex < toolCalls.Length; toolCallIndex++)
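
Because the streamed assistant message now carries FunctionCallContent items alongside its text, callers can inspect tool calls from chat history in an LLM-agnostic way instead of reaching for OpenAI-specific types. A small consumption sketch; the PrintToolCalls helper and sample values are hypothetical:

using System;
using System.Linq;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

static void PrintToolCalls(ChatMessageContent message)
{
    // FunctionCallContent is the LLM-agnostic representation attached above.
    foreach (var call in message.Items.OfType<FunctionCallContent>())
    {
        Console.WriteLine($"{call.PluginName}-{call.FunctionName} (id: {call.Id})");
    }
}

var message = new ChatMessageContent(AuthorRole.Assistant, "Checking the weather.");
message.Items.Add(new FunctionCallContent("GetForecast", "WeatherPlugin", "call_abc"));
PrintToolCalls(message); // WeatherPlugin-GetForecast (id: call_abc)
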
@@ -1350,48 +1364,52 @@ private OpenAIChatMessageContent GetChatMessage(ChatChoice chatChoice, ChatCompl
             // This allows consumers to work with functions in an LLM-agnostic way.
             if (toolCall is ChatCompletionsFunctionToolCall functionToolCall)
             {
-                Exception? exception = null;
-                KernelArguments? arguments = null;
-                try
-                {
-                    arguments = JsonSerializer.Deserialize<KernelArguments>(functionToolCall.Arguments);
-                    if (arguments is not null)
-                    {
-                        // Iterate over copy of the names to avoid mutating the dictionary while enumerating it
-                        var names = arguments.Names.ToArray();
-                        foreach (var name in names)
-                        {
-                            arguments[name] = arguments[name]?.ToString();
-                        }
-                    }
-                }
-                catch (JsonException ex)
-                {
-                    exception = new KernelException("Error: Function call arguments were invalid JSON.", ex);
-
-                    if (this.Logger.IsEnabled(LogLevel.Debug))
-                    {
-                        this.Logger.LogDebug(ex, "Failed to deserialize function arguments ({FunctionName}/{FunctionId}).", functionToolCall.Name, functionToolCall.Id);
-                    }
-                }
-
-                var functionName = FunctionName.Parse(functionToolCall.Name, OpenAIFunction.NameSeparator);
-
-                var functionCallContent = new FunctionCallContent(
-                    functionName: functionName.Name,
-                    pluginName: functionName.PluginName,
-                    id: functionToolCall.Id,
-                    arguments: arguments)
-                {
-                    InnerContent = functionToolCall,
-                    Exception = exception
-                };
-
-                message.Items.Add(functionCallContent);
+                var functionCallContent = this.GetFunctionCallContent(functionToolCall);
+                message.Items.Add(functionCallContent);
             }
         }
 
         return message;
     }
 
+    private FunctionCallContent GetFunctionCallContent(ChatCompletionsFunctionToolCall toolCall)
+    {
+        KernelArguments? arguments = null;
+        Exception? exception = null;
+        try
+        {
+            arguments = JsonSerializer.Deserialize<KernelArguments>(toolCall.Arguments);
+            if (arguments is not null)
+            {
+                // Iterate over copy of the names to avoid mutating the dictionary while enumerating it
+                var names = arguments.Names.ToArray();
+                foreach (var name in names)
+                {
+                    arguments[name] = arguments[name]?.ToString();
+                }
+            }
+        }
+        catch (JsonException ex)
+        {
+            exception = new KernelException("Error: Function call arguments were invalid JSON.", ex);
+
+            if (this.Logger.IsEnabled(LogLevel.Debug))
+            {
+                this.Logger.LogDebug(ex, "Failed to deserialize function arguments ({FunctionName}/{FunctionId}).", toolCall.Name, toolCall.Id);
+            }
+        }
+
+        var functionName = FunctionName.Parse(toolCall.Name, OpenAIFunction.NameSeparator);
+
+        return new FunctionCallContent(
+            functionName: functionName.Name,
+            pluginName: functionName.PluginName,
+            id: toolCall.Id,
+            arguments: arguments)
+        {
+            InnerContent = toolCall,
+            Exception = exception
+        };
+    }
+
     private static void ValidateMaxTokens(int? maxTokens)
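
The extracted GetFunctionCallContent keeps the original error contract: malformed JSON arguments never throw out of the connector; they are logged at Debug level and surfaced as a KernelException on the content's Exception property. A self-contained illustration of that parsing behavior, using a plain dictionary where the connector uses KernelArguments:

using System;
using System.Collections.Generic;
using System.Text.Json;

var good = JsonSerializer.Deserialize<Dictionary<string, object?>>("{\"city\":\"Paris\",\"days\":3}");
Console.WriteLine(good!["city"]); // Paris

try
{
    JsonSerializer.Deserialize<Dictionary<string, object?>>("{\"city\":"); // truncated JSON
}
catch (JsonException ex)
{
    // The connector catches exactly this, wraps it in a KernelException, and stores it
    // on FunctionCallContent.Exception so the caller decides how to react.
    Console.WriteLine(ex.Message);
}
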
The remaining hunks are in the ModelDiagnostics utility that ClientCore calls into (the extension class providing StartCompletionActivity and EndStreaming).

@@ -78,12 +78,17 @@ public static void SetCompletionResponse(this Activity activity, IEnumerable<Cha
     /// <summary>
     /// Notify the end of streaming for a given activity.
     /// </summary>
-    public static void EndStreaming(this Activity activity, IEnumerable<StreamingKernelContent>? contents, int? promptTokens = null, int? completionTokens = null)
+    public static void EndStreaming(
+        this Activity activity,
+        IEnumerable<StreamingKernelContent>? contents,
+        IEnumerable<FunctionCallContent>? toolCalls = null,
+        int? promptTokens = null,
+        int? completionTokens = null)
     {
         if (IsModelDiagnosticsEnabled())
         {
             var choices = OrganizeStreamingContent(contents);
-            SetCompletionResponse(activity, choices, promptTokens, completionTokens);
+            SetCompletionResponse(activity, choices, toolCalls, promptTokens, completionTokens);
         }
     }
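
On the caller side the shutdown path keeps its shape: one finally block reports both the streamed chunks and the captured tool calls, whether the enumeration completed or faulted. A sketch of the expected call pattern, mirroring the try/finally in ClientCore.cs above; the setup values are placeholders:

using System.Collections.Generic;
using System.Diagnostics;
using Microsoft.SemanticKernel;

Activity? activity = null; // e.g. the result of ModelDiagnostics.StartCompletionActivity(...)
List<StreamingKernelContent>? streamedContents = activity is not null ? [] : null;
FunctionCallContent[]? functionCallContents = null;
try
{
    // ... enumerate the model response, yielding chunks and filling both collections ...
}
finally
{
    activity?.EndStreaming(streamedContents, functionCallContents);
}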

@@ -120,6 +125,12 @@ public static bool IsModelDiagnosticsEnabled()
         return (s_enableDiagnostics || s_enableSensitiveEvents) && s_activitySource.HasListeners();
     }
 
+    /// <summary>
+    /// Check if sensitive events are enabled.
+    /// Sensitive events are enabled if EnableSensitiveEvents is set to true and there are listeners.
+    /// </summary>
+    public static bool IsSensitiveEventsEnabled() => s_enableSensitiveEvents && s_activitySource.HasListeners();
+
     #region Private
     private static void AddOptionalTags<TPromptExecutionSettings>(Activity? activity, TPromptExecutionSettings? executionSettings)
         where TPromptExecutionSettings : PromptExecutionSettings
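
For context, s_enableSensitiveEvents is driven by an AppContext switch, and the new check additionally requires an attached listener. A hedged sketch of satisfying both conditions in a test or host app; the switch name below matches the experimental switch Semantic Kernel documented around this time, but treat it as an assumption:

using System;
using System.Diagnostics;

// Set before ModelDiagnostics is first used; the s_ fields suggest the switch
// is read once at type initialization.
AppContext.SetSwitch("Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnosticsSensitive", true);

using var listener = new ActivityListener
{
    ShouldListenTo = source => source.Name.StartsWith("Microsoft.SemanticKernel", StringComparison.Ordinal),
    Sample = (ref ActivityCreationOptions<ActivityContext> _) => ActivitySamplingResult.AllData,
};
ActivitySource.AddActivityListener(listener);

// With the switch on and a listener attached, IsSensitiveEventsEnabled() returns
// true and the connector starts materializing FunctionCallContent for diagnostics.
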
@@ -170,8 +181,11 @@ private static string ToOpenAIFormat(IEnumerable<ChatMessageContent> chatHistory
             sb.Append(message.Role);
             sb.Append("\", \"content\": ");
             sb.Append(JsonSerializer.Serialize(message.Content));
-            sb.Append(", \"tool_calls\": ");
-            ToOpenAIFormat(sb, message.Items);
+            if (message.Items.OfType<FunctionCallContent>().Any())
+            {
+                sb.Append(", \"tool_calls\": ");
+                ToOpenAIFormat(sb, message.Items);
+            }
             sb.Append('}');
 
             isFirst = false;
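
The new guard means a plain text message no longer gets a dangling "tool_calls" key in the OpenAI-formatted diagnostics payload; the key is emitted only when the message actually contains FunctionCallContent items. A minimal standalone rendition of the guarded serialization, assuming the standard Semantic Kernel content types:

using System.Linq;
using System.Text;
using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var sb = new StringBuilder();
var message = new ChatMessageContent(AuthorRole.Assistant, "Sure, one moment.");

sb.Append("{\"role\": \"").Append(message.Role).Append("\", \"content\": ")
  .Append(JsonSerializer.Serialize(message.Content));
if (message.Items.OfType<FunctionCallContent>().Any())
{
    sb.Append(", \"tool_calls\": [/* serialized items */]"); // only reached when tool calls exist
}
sb.Append('}');
// -> {"role": "assistant", "content": "Sure, one moment."}
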
@@ -307,6 +321,7 @@ private static void SetCompletionResponse<T>(
     private static void SetCompletionResponse(
         Activity activity,
         Dictionary<int, List<StreamingKernelContent>> choices,
+        IEnumerable<FunctionCallContent>? toolCalls,
         int? promptTokens,
         int? completionTokens)
     {
@@ -334,6 +349,12 @@ private static void SetCompletionResponse(
                 var chatMessage = choiceContents.Value.Select(c => c.ToString()).Aggregate((a, b) => a + b);
                 return new ChatMessageContent(lastContent.Role ?? AuthorRole.Assistant, chatMessage, metadata: lastContent.Metadata);
             }).ToList();
+            // It's currently not allowed to request multiple results per prompt while auto-invoke is enabled.
+            // Therefore, we can assume that there is only one completion per prompt when tool calls are present.
+            foreach (var functionCall in toolCalls ?? [])
+            {
+                chatCompletions.FirstOrDefault()?.Items.Add(functionCall);
+            }
             SetCompletionResponse(activity, chatCompletions, promptTokens, completionTokens, ToOpenAIFormat);
             break;
         };
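
The comment in this hunk encodes a real constraint: auto-invoke currently rejects requests for multiple results per prompt, so when tool calls are present there is exactly one completion and attaching them to FirstOrDefault() is safe. For orientation, here is a hedged sketch of the grouping that produces the choices dictionary upstream; the actual OrganizeStreamingContent implementation may differ:

using System.Collections.Generic;
using Microsoft.SemanticKernel;

// Bucket streamed chunks by choice index so each completion can be reassembled
// independently; with tool calls present, only choice 0 ends up populated.
static Dictionary<int, List<StreamingKernelContent>> OrganizeByChoice(IEnumerable<StreamingKernelContent>? contents)
{
    Dictionary<int, List<StreamingKernelContent>> choices = [];
    foreach (var content in contents ?? [])
    {
        if (!choices.TryGetValue(content.ChoiceIndex, out var list))
        {
            choices[content.ChoiceIndex] = list = [];
        }
        list.Add(content);
    }
    return choices;
}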