
.Net: Trace ChatHistory and PromptExecutionSettings in IChatCompletionServices #6306


Merged · 2 commits · May 17, 2024
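This PR adds trace-level logging of the incoming `ChatHistory` and the resolved `PromptExecutionSettings` to the Gemini, HuggingFace, MistralAI, and OpenAI chat completion clients. Each diff below uses the same pattern: the JSON serialization is guarded by `ILogger.IsEnabled`, so the potentially large payload is only serialized when a trace listener is attached. A minimal sketch of the shared pattern; the class and method names are illustrative, not code from the PR:

```csharp
using System.Text.Json;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

internal sealed class ChatClientSketch
{
    private readonly ILogger _logger;

    internal ChatClientSketch(ILogger logger) => this._logger = logger;

    internal void TraceRequest(ChatHistory chatHistory, PromptExecutionSettings settings)
    {
        // The IsEnabled guard keeps the serialization cost off the hot path
        // unless trace logging is actually turned on.
        if (this._logger.IsEnabled(LogLevel.Trace))
        {
            this._logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}",
                JsonSerializer.Serialize(chatHistory),
                JsonSerializer.Serialize(settings));
        }
    }
}
```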
14 changes: 2 additions & 12 deletions dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs
@@ -16,7 +16,7 @@ internal abstract class ClientBase
{
private readonly Func<Task<string>>? _bearerTokenProvider;

private readonly ILogger _logger;
protected ILogger Logger { get; }

protected HttpClient HttpClient { get; }

@@ -37,7 +37,7 @@ protected ClientBase(
Verify.NotNull(httpClient);

this.HttpClient = httpClient;
this._logger = logger ?? NullLogger.Instance;
this.Logger = logger ?? NullLogger.Instance;
}

protected static void ValidateMaxTokens(int? maxTokens)
@@ -100,16 +100,6 @@ protected async Task<HttpRequestMessage> CreateHttpRequestAsync(object requestDa
return httpRequestMessage;
}

protected void Log(LogLevel logLevel, string? message, params object[] args)
{
if (this._logger.IsEnabled(logLevel))
{
#pragma warning disable CA2254 // Template should be a constant string.
this._logger.Log(logLevel, message, args);
#pragma warning restore CA2254
}
}

protected static string GetApiVersionSubLink(GoogleAIVersion apiVersion)
=> apiVersion switch
{
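This hunk replaces the private `_logger` field and the `Log` wrapper with a protected `Logger` property. The wrapper took its message template as a runtime parameter, which is why CA2254 ("template should be a constant string") had to be suppressed; call sites now invoke the `ILogger` extension methods directly with constant templates, each behind its own `IsEnabled` guard. A sketch of the call-site change, using a message from the Gemini client below (`toolCallCount` is a stand-in local, not a name from the PR):

```csharp
using Microsoft.Extensions.Logging;

internal static class CallSiteSketch
{
    // Before: this.Log(LogLevel.Debug, "Tool requests: {Requests}", toolCallCount);
    // The template reached the logger as a runtime value, defeating analyzer checks.

    // After: a constant template the analyzer can verify, behind an explicit guard.
    internal static void LogToolRequests(ILogger logger, int toolCallCount)
    {
        if (logger.IsEnabled(LogLevel.Debug))
        {
            logger.LogDebug("Tool requests: {Requests}", toolCallCount);
        }
    }
}
```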
@@ -7,6 +7,7 @@
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
@@ -159,7 +160,7 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
var state = ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

for (state.Iteration = 1; ; state.Iteration++)
{
@@ -222,7 +223,7 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
Kernel? kernel = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var state = ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);
var state = this.ValidateInputAndCreateChatCompletionState(chatHistory, kernel, executionSettings);

for (state.Iteration = 1; ; state.Iteration++)
{
@@ -291,7 +292,7 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
}
}

private static ChatCompletionState ValidateInputAndCreateChatCompletionState(
private ChatCompletionState ValidateInputAndCreateChatCompletionState(
ChatHistory chatHistory,
Kernel? kernel,
PromptExecutionSettings? executionSettings)
@@ -302,6 +303,13 @@ private static ChatCompletionState ValidateInputAndCreateChatCompletionState(
var geminiExecutionSettings = GeminiPromptExecutionSettings.FromExecutionSettings(executionSettings);
ValidateMaxTokens(geminiExecutionSettings.MaxTokens);

if (this.Logger.IsEnabled(LogLevel.Trace))
{
this.Logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}",
JsonSerializer.Serialize(chatHistory),
JsonSerializer.Serialize(geminiExecutionSettings));
}

return new ChatCompletionState()
{
AutoInvoke = CheckAutoInvokeCondition(kernel, geminiExecutionSettings),
@@ -363,13 +371,20 @@ private async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMess

private async Task ProcessFunctionsAsync(ChatCompletionState state, CancellationToken cancellationToken)
{
this.Log(LogLevel.Debug, "Tool requests: {Requests}", state.LastMessage!.ToolCalls!.Count);
this.Log(LogLevel.Trace, "Function call requests: {FunctionCall}",
string.Join(", ", state.LastMessage.ToolCalls.Select(ftc => ftc.ToString())));
if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug("Tool requests: {Requests}", state.LastMessage!.ToolCalls!.Count);
}

if (this.Logger.IsEnabled(LogLevel.Trace))
{
this.Logger.LogTrace("Function call requests: {FunctionCall}",
string.Join(", ", state.LastMessage!.ToolCalls!.Select(ftc => ftc.ToString())));
}

// We must send back a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
foreach (var toolCall in state.LastMessage.ToolCalls)
foreach (var toolCall in state.LastMessage!.ToolCalls!)
{
await this.ProcessSingleToolCallAsync(state, toolCall, cancellationToken).ConfigureAwait(false);
}
@@ -380,8 +395,11 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
{
// Don't add any tools as we've reached the maximum attempts limit.
this.Log(LogLevel.Debug, "Maximum use ({MaximumUse}) reached; removing the tools.",
state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts);
if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug("Maximum use ({MaximumUse}) reached; removing the tools.",
state.ExecutionSettings.ToolCallBehavior!.MaximumUseAttempts);
}
}
else
{
@@ -394,8 +412,11 @@ private async Task ProcessFunctionsAsync(ChatCompletionState state, Cancellation
if (state.Iteration >= state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
{
state.AutoInvoke = false;
this.Log(LogLevel.Debug, "Maximum auto-invoke ({MaximumAutoInvoke}) reached.",
state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts);
if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug("Maximum auto-invoke ({MaximumAutoInvoke}) reached.",
state.ExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts);
}
}
}

@@ -473,9 +494,9 @@ private void AddToolResponseMessage(
FunctionResult? functionResponse,
string? errorMessage)
{
if (errorMessage is not null)
if (errorMessage is not null && this.Logger.IsEnabled(LogLevel.Debug))
{
this.Log(LogLevel.Debug, "Failed to handle tool request ({ToolName}). {Error}", tool.FullyQualifiedName, errorMessage);
this.Logger.LogDebug("Failed to handle tool request ({ToolName}). {Error}", tool.FullyQualifiedName, errorMessage);
}

var message = new GeminiChatMessageContent(AuthorRole.Tool,
@@ -690,16 +711,18 @@ private void LogUsageMetadata(GeminiMetadata metadata)
{
if (metadata.TotalTokenCount <= 0)
{
this.Log(LogLevel.Debug, "Gemini usage information is not available.");
this.Logger.LogDebug("Gemini usage information is not available.");
return;
}

this.Log(
LogLevel.Debug,
"Gemini usage metadata: Candidates tokens: {CandidatesTokens}, Prompt tokens: {PromptTokens}, Total tokens: {TotalTokens}",
metadata.CandidatesTokenCount,
metadata.PromptTokenCount,
metadata.TotalTokenCount);
if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug(
"Gemini usage metadata: Candidates tokens: {CandidatesTokens}, Prompt tokens: {PromptTokens}, Total tokens: {TotalTokens}",
metadata.CandidatesTokenCount,
metadata.PromptTokenCount,
metadata.TotalTokenCount);
}

s_promptTokensCounter.Add(metadata.PromptTokenCount);
s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
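The counters incremented at the end of `LogUsageMetadata` are `System.Diagnostics.Metrics` instruments. Their declarations fall outside the hunks shown, so the meter and instrument names below are assumptions for illustration only:

```csharp
using System.Diagnostics.Metrics;

internal static class GeminiMetricsSketch
{
    // Assumed names; the real declarations are not part of this diff.
    private static readonly Meter s_meter = new("Microsoft.SemanticKernel.Connectors.Google");

    internal static readonly Counter<int> s_promptTokensCounter =
        s_meter.CreateCounter<int>("semantic_kernel.connectors.google.tokens.prompt");

    internal static readonly Counter<int> s_completionTokensCounter =
        s_meter.CreateCounter<int>("semantic_kernel.connectors.google.tokens.completion");
}
```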
@@ -8,6 +8,7 @@
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
@@ -273,6 +274,14 @@ private ChatCompletionRequest CreateChatRequest(
HuggingFacePromptExecutionSettings huggingFaceExecutionSettings)
{
HuggingFaceClient.ValidateMaxTokens(huggingFaceExecutionSettings.MaxTokens);

if (this._clientCore.Logger.IsEnabled(LogLevel.Trace))
{
this._clientCore.Logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}",
JsonSerializer.Serialize(chatHistory),
JsonSerializer.Serialize(huggingFaceExecutionSettings));
}

var request = ChatCompletionRequest.FromChatHistoryAndExecutionSettings(chatHistory, huggingFaceExecutionSettings);
return request;
}
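The HuggingFace change logs through `this._clientCore.Logger` rather than a logger of its own, which implies the shared core client exposes its logger the same way the Google `ClientBase` now does. An assumed sketch of that shape (names and constructor are illustrative):

```csharp
using System.Net.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

// Assumed shape: the core client surfaces one ILogger that sibling clients,
// such as the message API client above, reuse instead of owning their own.
internal sealed class CoreClientSketch
{
    internal HttpClient HttpClient { get; }
    internal ILogger Logger { get; }

    internal CoreClientSketch(HttpClient httpClient, ILogger? logger = null)
    {
        this.HttpClient = httpClient;
        this.Logger = logger ?? NullLogger.Instance;
    }
}
```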
@@ -611,24 +611,27 @@ private void ValidateChatHistory(ChatHistory chatHistory)
}
}

private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool stream, ChatHistory chatHistory, MistralAIPromptExecutionSettings? executionSettings, Kernel? kernel = null)
private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool stream, ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, Kernel? kernel = null)
{
if (this._logger.IsEnabled(LogLevel.Trace))
{
this._logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}",
JsonSerializer.Serialize(chatHistory),
JsonSerializer.Serialize(executionSettings));
}

var request = new ChatCompletionRequest(modelId)
{
Stream = stream,
Messages = chatHistory.SelectMany(chatMessage => this.ToMistralChatMessages(chatMessage, executionSettings?.ToolCallBehavior)).ToList(),
Temperature = executionSettings.Temperature,
TopP = executionSettings.TopP,
MaxTokens = executionSettings.MaxTokens,
SafePrompt = executionSettings.SafePrompt,
RandomSeed = executionSettings.RandomSeed
};

if (executionSettings is not null)
{
request.Temperature = executionSettings.Temperature;
request.TopP = executionSettings.TopP;
request.MaxTokens = executionSettings.MaxTokens;
request.SafePrompt = executionSettings.SafePrompt;
request.RandomSeed = executionSettings.RandomSeed;

executionSettings.ToolCallBehavior?.ConfigureRequest(kernel, request);
}
executionSettings.ToolCallBehavior?.ConfigureRequest(kernel, request);

return request;
}
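Besides the trace logging, this hunk tightens `CreateChatCompletionRequest` to take a non-nullable `MistralAIPromptExecutionSettings`, letting the null-guarded mutation block collapse into a plain object initializer. That works if callers normalize settings first, e.g. via `MistralAIPromptExecutionSettings.FromExecutionSettings`, which by the convention visible in the Gemini client returns a default instance for `null` input. An assumed call-site fragment (`this._modelId` is a stand-in):

```csharp
// Assumed call site: settings are resolved up front, so the request factory never sees null.
var mistralExecutionSettings = MistralAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
var request = this.CreateChatCompletionRequest(this._modelId, stream: false, chatHistory, mistralExecutionSettings, kernel);
```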
19 changes: 11 additions & 8 deletions dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -384,7 +384,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
ValidateAutoInvoke(autoInvoke, chatExecutionSettings.ResultsPerPrompt);

// Create the Azure SDK ChatCompletionOptions instance from all available information.
var chatOptions = CreateChatCompletionsOptions(chatExecutionSettings, chat, kernel, this.DeploymentOrModelName);
var chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chat, kernel, this.DeploymentOrModelName);

for (int requestIndex = 1; ; requestIndex++)
{
@@ -642,7 +642,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
bool autoInvoke = kernel is not null && chatExecutionSettings.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;
ValidateAutoInvoke(autoInvoke, chatExecutionSettings.ResultsPerPrompt);

var chatOptions = CreateChatCompletionsOptions(chatExecutionSettings, chat, kernel, this.DeploymentOrModelName);
var chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chat, kernel, this.DeploymentOrModelName);

StringBuilder? contentBuilder = null;
Dictionary<int, string>? toolCallIdsByIndex = null;
@@ -1060,7 +1060,7 @@ private static CompletionsOptions CreateCompletionsOptions(string text, OpenAIPr
return options;
}

private static ChatCompletionsOptions CreateChatCompletionsOptions(
private ChatCompletionsOptions CreateChatCompletionsOptions(
OpenAIPromptExecutionSettings executionSettings,
ChatHistory chatHistory,
Kernel? kernel,
@@ -1071,6 +1071,13 @@ private static ChatCompletionsOptions CreateChatCompletionsOptions(
throw new ArgumentOutOfRangeException($"{nameof(executionSettings)}.{nameof(executionSettings.ResultsPerPrompt)}", executionSettings.ResultsPerPrompt, $"The value must be in range between 1 and {MaxResultsPerPrompt}, inclusive.");
}

if (this.Logger.IsEnabled(LogLevel.Trace))
{
this.Logger.LogTrace("ChatHistory: {ChatHistory}, Settings: {Settings}",
JsonSerializer.Serialize(chatHistory),
JsonSerializer.Serialize(executionSettings));
}

var options = new ChatCompletionsOptions
{
MaxTokens = executionSettings.MaxTokens,
@@ -1432,11 +1439,7 @@ private void CaptureUsageDetails(CompletionsUsage usage)
{
if (usage is null)
{
if (this.Logger.IsEnabled(LogLevel.Debug))
{
this.Logger.LogDebug("Usage information is not available.");
}

this.Logger.LogDebug("Usage information is not available.");
return;
}

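The final hunk drops the `IsEnabled(LogLevel.Debug)` guard around a message with no arguments: `LogDebug` performs the level check internally, and with nothing to format there is no cost to avoid, so the guard was pure noise. To surface the new trace output end to end, a consumer can register a trace-level `ILoggerFactory` with the kernel; a minimal sketch, assuming the OpenAI connector and console logging packages are referenced and with placeholder model id and key:

```csharp
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;

// Trace-level console logging surfaces the new "ChatHistory: ..., Settings: ..." entries.
using var loggerFactory = LoggerFactory.Create(logging =>
    logging.AddConsole().SetMinimumLevel(LogLevel.Trace));

var kernelBuilder = Kernel.CreateBuilder();
kernelBuilder.Services.AddSingleton(loggerFactory);
kernelBuilder.AddOpenAIChatCompletion(modelId: "gpt-4o", apiKey: "<api-key>");
var kernel = kernelBuilder.Build();
```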