
.Net: Fix MistralAI logging #6315

Merged
merged 1 commit into from May 17, 2024

GeminiChatCompletionClient.cs

@@ -29,7 +29,7 @@ internal sealed class GeminiChatCompletionClient : ClientBase
private readonly Uri _chatGenerationEndpoint;
private readonly Uri _chatStreamingEndpoint;

-    private static readonly string s_namespace = typeof(GeminiChatCompletionClient).Namespace!;
+    private static readonly string s_namespace = typeof(GoogleAIGeminiChatCompletionService).Namespace!;

/// <summary>
/// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
@@ -622,7 +622,28 @@ private static void ValidateGeminiResponse(GeminiResponse geminiResponse)
}

private void LogUsage(List<GeminiChatMessageContent> chatMessageContents)
-        => this.LogUsageMetadata(chatMessageContents[0].Metadata!);
+    {
+        GeminiMetadata? metadata = chatMessageContents[0].Metadata;
+
+        if (metadata is null || metadata.TotalTokenCount <= 0)
+        {
+            this.Logger.LogDebug("Token usage information unavailable.");
+            return;
+        }
+
+        if (this.Logger.IsEnabled(LogLevel.Information))
+        {
+            this.Logger.LogInformation(
+                "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.",
+                metadata.PromptTokenCount,
+                metadata.CandidatesTokenCount,
+                metadata.TotalTokenCount);
+        }
+
+        s_promptTokensCounter.Add(metadata.PromptTokenCount);
+        s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
+        s_totalTokensCounter.Add(metadata.TotalTokenCount);
+    }

private List<GeminiChatMessageContent> GetChatMessageContentsFromResponse(GeminiResponse geminiResponse)
=> geminiResponse.Candidates!.Select(candidate => this.GetChatMessageContentFromCandidate(geminiResponse, candidate)).ToList();
@@ -707,28 +728,6 @@ private static GeminiMetadata GetResponseMetadata(
ResponseSafetyRatings = candidate.SafetyRatings?.ToList(),
};

-    private void LogUsageMetadata(GeminiMetadata metadata)
-    {
-        if (metadata.TotalTokenCount <= 0)
-        {
-            this.Logger.LogDebug("Gemini usage information is not available.");
-            return;
-        }
-
-        if (this.Logger.IsEnabled(LogLevel.Debug))
-        {
-            this.Logger.LogDebug(
-                "Gemini usage metadata: Candidates tokens: {CandidatesTokens}, Prompt tokens: {PromptTokens}, Total tokens: {TotalTokens}",
-                metadata.CandidatesTokenCount,
-                metadata.PromptTokenCount,
-                metadata.TotalTokenCount);
-        }
-
-        s_promptTokensCounter.Add(metadata.PromptTokenCount);
-        s_completionTokensCounter.Add(metadata.CandidatesTokenCount);
-        s_totalTokensCounter.Add(metadata.TotalTokenCount);
-    }
-
private sealed class ChatCompletionState
{
internal ChatHistory ChatHistory { get; set; } = null!;
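
Note on the s_namespace change above: the meter backing these token counters is named after the connector's namespace, so deriving it from the public service type rather than the internal client changes the emitted metric names. A minimal sketch of observing the counters in-process with System.Diagnostics.Metrics.MeterListener; the name prefix below is an assumption for illustration, the actual meter name is the value of typeof(GoogleAIGeminiChatCompletionService).Namespace:

using System;
using System.Diagnostics.Metrics;

// Subscribe to token counters published by Semantic Kernel connectors.
using var listener = new MeterListener();
listener.InstrumentPublished = (instrument, l) =>
{
    // Assumed prefix; match the meter name your connector actually uses.
    if (instrument.Meter.Name.StartsWith("Microsoft.SemanticKernel.Connectors.", StringComparison.Ordinal))
    {
        l.EnableMeasurementEvents(instrument);
    }
};
listener.SetMeasurementEventCallback<int>((instrument, measurement, tags, state) =>
    Console.WriteLine($"{instrument.Meter.Name}/{instrument.Name} += {measurement}"));
listener.Start();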

HuggingFaceMessageApiClient.cs

@@ -27,7 +27,7 @@ internal sealed class HuggingFaceMessageApiClient
{
private readonly HuggingFaceClient _clientCore;

-    private static readonly string s_namespace = typeof(HuggingFaceMessageApiClient).Namespace!;
+    private static readonly string s_namespace = typeof(HuggingFaceChatCompletionService).Namespace!;

/// <summary>
/// Instance of <see cref="Meter"/> for metrics.
@@ -179,20 +179,25 @@ internal async Task<IReadOnlyList<ChatMessageContent>> CompleteChatMessageAsync(

private void LogChatCompletionUsage(HuggingFacePromptExecutionSettings executionSettings, ChatCompletionResponse chatCompletionResponse)
{
-        if (this._clientCore.Logger.IsEnabled(LogLevel.Debug))
+        if (chatCompletionResponse.Usage is null)
         {
-            this._clientCore.Logger.Log(
-                LogLevel.Debug,
-                "HuggingFace chat completion usage - ModelId: {ModelId}, Prompt tokens: {PromptTokens}, Completion tokens: {CompletionTokens}, Total tokens: {TotalTokens}",
-                chatCompletionResponse.Model,
-                chatCompletionResponse.Usage!.PromptTokens,
-                chatCompletionResponse.Usage!.CompletionTokens,
-                chatCompletionResponse.Usage!.TotalTokens);
+            this._clientCore.Logger.LogDebug("Token usage information unavailable.");
+            return;
         }

-        s_promptTokensCounter.Add(chatCompletionResponse.Usage!.PromptTokens);
-        s_completionTokensCounter.Add(chatCompletionResponse.Usage!.CompletionTokens);
-        s_totalTokensCounter.Add(chatCompletionResponse.Usage!.TotalTokens);
+        if (this._clientCore.Logger.IsEnabled(LogLevel.Information))
+        {
+            this._clientCore.Logger.LogInformation(
+                "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}. ModelId: {ModelId}.",
+                chatCompletionResponse.Usage.PromptTokens,
+                chatCompletionResponse.Usage.CompletionTokens,
+                chatCompletionResponse.Usage.TotalTokens,
+                chatCompletionResponse.Model);
+        }
+
+        s_promptTokensCounter.Add(chatCompletionResponse.Usage.PromptTokens);
+        s_completionTokensCounter.Add(chatCompletionResponse.Usage.CompletionTokens);
+        s_totalTokensCounter.Add(chatCompletionResponse.Usage.TotalTokens);
}

private static List<ChatMessageContent> GetChatMessageContentsFromResponse(ChatCompletionResponse response, string modelId)
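
The guard matters beyond logging: the old code updated the counters through null-forgiving dereferences (Usage!), so a response without a usage block could throw at runtime even with logging disabled. A tiny self-contained illustration using stand-in types, since the connector's ChatCompletionResponse is internal:

var response = new FakeResponse();                        // server sent no usage block
// Old pattern: int total = response.Usage!.TotalTokens;  // NullReferenceException
// New pattern: guard first, then read the values safely.
if (response.Usage is null)
{
    Console.WriteLine("Token usage information unavailable.");
    return;
}
Console.WriteLine($"Total tokens: {response.Usage.TotalTokens}");

// Stand-in types for illustration only; not the connector's real classes.
sealed class FakeUsage { public int TotalTokens { get; init; } }
sealed class FakeResponse { public FakeUsage? Usage { get; init; } }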

MistralClient.cs

@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
+using System.Diagnostics.Metrics;
using System.IO;
using System.Linq;
using System.Net.Http;
@@ -26,8 +27,6 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;
/// </summary>
internal sealed class MistralClient
{
-    private const string ModelProvider = "mistralai";
-
internal MistralClient(
string modelId,
HttpClient httpClient,
@@ -67,6 +66,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(
{
using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: false);
responseData = await this.SendRequestAsync<ChatCompletionResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);
+        this.LogUsage(responseData?.Usage);
if (responseData is null || responseData.Choices is null || responseData.Choices.Count == 0)
{
throw new KernelException("Chat completions not found");
@@ -572,6 +572,9 @@ internal async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<
private readonly ILogger _logger;
private readonly StreamJsonParser _streamJsonParser;

+    /// <summary>Provider name used for diagnostics.</summary>
+    private const string ModelProvider = "mistralai";
+
/// <summary>
/// The maximum number of auto-invokes that can be in-flight at any given time as part of the current
/// asynchronous chain of execution.
@@ -593,6 +596,63 @@ internal async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<
/// <summary>Tracking <see cref="AsyncLocal{Int32}"/> for <see cref="MaxInflightAutoInvokes"/>.</summary>
private static readonly AsyncLocal<int> s_inflightAutoInvokes = new();

+    private static readonly string s_namespace = typeof(MistralAIChatCompletionService).Namespace!;
+
+    /// <summary>
+    /// Instance of <see cref="Meter"/> for metrics.
+    /// </summary>
+    private static readonly Meter s_meter = new(s_namespace);
+
+    /// <summary>
+    /// Instance of <see cref="Counter{T}"/> to keep track of the number of prompt tokens used.
+    /// </summary>
+    private static readonly Counter<int> s_promptTokensCounter =
+        s_meter.CreateCounter<int>(
+            name: $"{s_namespace}.tokens.prompt",
+            unit: "{token}",
+            description: "Number of prompt tokens used");
+
+    /// <summary>
+    /// Instance of <see cref="Counter{T}"/> to keep track of the number of completion tokens used.
+    /// </summary>
+    private static readonly Counter<int> s_completionTokensCounter =
+        s_meter.CreateCounter<int>(
+            name: $"{s_namespace}.tokens.completion",
+            unit: "{token}",
+            description: "Number of completion tokens used");
+
+    /// <summary>
+    /// Instance of <see cref="Counter{T}"/> to keep track of the total number of tokens used.
+    /// </summary>
+    private static readonly Counter<int> s_totalTokensCounter =
+        s_meter.CreateCounter<int>(
+            name: $"{s_namespace}.tokens.total",
+            unit: "{token}",
+            description: "Number of tokens used");
+
+    /// <summary>Log token usage to the logger and metrics.</summary>
+    private void LogUsage(MistralUsage? usage)
+    {
+        if (usage is null || usage.PromptTokens is null || usage.CompletionTokens is null || usage.TotalTokens is null)
+        {
+            this._logger.LogDebug("Usage information unavailable.");
+            return;
+        }
+
+        if (this._logger.IsEnabled(LogLevel.Information))
+        {
+            this._logger.LogInformation(
+                "Prompt tokens: {PromptTokens}. Completion tokens: {CompletionTokens}. Total tokens: {TotalTokens}.",
+                usage.PromptTokens,
+                usage.CompletionTokens,
+                usage.TotalTokens);
+        }
+
+        s_promptTokensCounter.Add(usage.PromptTokens.Value);
+        s_completionTokensCounter.Add(usage.CompletionTokens.Value);
+        s_totalTokensCounter.Add(usage.TotalTokens.Value);
+    }
+
/// <summary>
/// Messages are required and the first prompt role should be user or system.
/// </summary>
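
With these counters in place, the new MistralAI token metrics can be exported the same way as the other connectors'. A sketch using the OpenTelemetry SDK; it assumes the OpenTelemetry and OpenTelemetry.Exporter.Console packages, and the meter name is the connector namespace visible in the diff above:

using OpenTelemetry;
using OpenTelemetry.Metrics;

// Collect and print the MistralAI token counters; the meter name matches
// typeof(MistralAIChatCompletionService).Namespace.
using var meterProvider = Sdk.CreateMeterProviderBuilder()
    .AddMeter("Microsoft.SemanticKernel.Connectors.MistralAI")
    .AddConsoleExporter()
    .Build();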

MistralAIKernelBuilderExtensions.cs

@@ -3,6 +3,7 @@
using System;
using System.Net.Http;
using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.MistralAI;
using Microsoft.SemanticKernel.Embeddings;
@@ -38,7 +39,7 @@ public static IKernelBuilder AddMistralChatCompletion(
Verify.NotNullOrWhiteSpace(apiKey);

builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
-            new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)));
+            new MistralAIChatCompletionService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService<ILoggerFactory>()));

return builder;
}
@@ -64,7 +65,7 @@ public static IKernelBuilder AddMistralTextEmbeddingGeneration(
Verify.NotNull(builder);

builder.Services.AddKeyedSingleton<ITextEmbeddingGenerationService>(serviceId, (serviceProvider, _) =>
-            new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)));
+            new MistralAITextEmbeddingGenerationService(modelId, apiKey, endpoint, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), serviceProvider.GetService<ILoggerFactory>()));

return builder;
}
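
End to end, the fix means the Mistral services now resolve ILoggerFactory from the kernel's DI container, so the usage logs above appear once logging is registered. A hypothetical wiring sketch; the console provider and the model id are illustrative assumptions:

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;

var builder = Kernel.CreateBuilder();
builder.Services.AddLogging(logging => logging.AddConsole().SetMinimumLevel(LogLevel.Information));
builder.AddMistralChatCompletion(modelId: "mistral-small-latest", apiKey: "<your-api-key>");
Kernel kernel = builder.Build();
// Chat calls through this kernel now emit "Prompt tokens: ... Total tokens: ..." at Information.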

ClientCore.cs

@@ -166,7 +166,7 @@ internal async Task<IReadOnlyList<TextContent>> GetTextResultsAsync(
activity?.SetCompletionResponse(responseContent, responseData.Usage.PromptTokens, responseData.Usage.CompletionTokens);
}

-        this.CaptureUsageDetails(responseData.Usage);
+        this.LogUsage(responseData.Usage);

return responseContent;
}
@@ -396,7 +396,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(
try
{
responseData = (await RunRequestAsync(() => this.Client.GetChatCompletionsAsync(chatOptions, cancellationToken)).ConfigureAwait(false)).Value;
-            this.CaptureUsageDetails(responseData.Usage);
+            this.LogUsage(responseData.Usage);
if (responseData.Choices.Count == 0)
{
throw new KernelException("Chat completions not found");
@@ -1435,11 +1435,11 @@ private static async Task<T> RunRequestAsync<T>(Func<Task<T>> request)
/// Captures usage details, including token information.
/// </summary>
/// <param name="usage">Instance of <see cref="CompletionsUsage"/> with usage details.</param>
-    private void CaptureUsageDetails(CompletionsUsage usage)
+    private void LogUsage(CompletionsUsage usage)
{
if (usage is null)
{
this.Logger.LogDebug("Usage information is not available.");
this.Logger.LogDebug("Token usage information unavailable.");
return;
}
