.Net: fixed extension data in Model diagnostics #6275

Merged

Changes from 3 commits
10 changes: 8 additions & 2 deletions dotnet/samples/Demos/TelemetryWithAppInsights/Program.cs
@@ -294,8 +294,14 @@ public bool TrySelectAIService<T>(
Temperature = 0,
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
},
-    GoogleAIGeminiServiceKey => new GeminiPromptExecutionSettings(),
-    HuggingFaceServiceKey => new HuggingFacePromptExecutionSettings(),
+    GoogleAIGeminiServiceKey => new GeminiPromptExecutionSettings()
+    {
+        Temperature = 0,
+    },
+    HuggingFaceServiceKey => new HuggingFacePromptExecutionSettings()
+    {
+        Temperature = 0,
+    },
_ => null,
};

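The demo now pins Temperature = 0 on all three providers' settings, so once the diagnostics changes below land, each completion activity should carry the same temperature tag regardless of connector. A minimal sketch of observing that tag in-process; the ActivitySource name prefix and the gen_ai.request.temperature key follow the GenAI semantic conventions but are assumptions here, not taken from this diff:

```csharp
using System;
using System.Diagnostics;

// Hypothetical observer: prints the request temperature recorded on each
// completion activity once model diagnostics are enabled.
var listener = new ActivityListener
{
    // Assumption: Semantic Kernel activity sources share this name prefix.
    ShouldListenTo = source => source.Name.StartsWith("Microsoft.SemanticKernel", StringComparison.Ordinal),
    Sample = (ref ActivityCreationOptions<ActivityContext> _) => ActivitySamplingResult.AllDataAndRecorded,
    ActivityStopped = activity => Console.WriteLine(
        $"{activity.DisplayName}: temperature = {activity.GetTagItem("gen_ai.request.temperature")}")
};
ActivitySource.AddActivityListener(listener);
```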
@@ -166,7 +166,7 @@ public async Task<IReadOnlyList<ChatMessageContent>> GenerateChatMessageAsync(
GeminiResponse geminiResponse;
List<GeminiChatMessageContent> chatResponses;
using (var activity = ModelDiagnostics.StartCompletionActivity(
-        this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, executionSettings))
+        this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
try
{
@@ -227,7 +227,7 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
for (state.Iteration = 1; ; state.Iteration++)
{
using (var activity = ModelDiagnostics.StartCompletionActivity(
-        this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, executionSettings))
+        this._chatGenerationEndpoint, this._modelId, ModelProvider, chatHistory, state.ExecutionSettings))
{
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
@@ -132,9 +132,11 @@ public async Task<IReadOnlyList<TextContent>> GenerateTextAsync(
{
string modelId = executionSettings?.ModelId ?? this.ModelId;
var endpoint = this.GetTextGenerationEndpoint(modelId);
-    var request = this.CreateTextRequest(prompt, executionSettings);

-    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, executionSettings);
+    var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
+    var request = this.CreateTextRequest(prompt, huggingFaceExecutionSettings);
+
+    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, huggingFaceExecutionSettings);
using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey);

TextGenerationResponse response;
@@ -154,7 +156,7 @@
var textContents = GetTextContentsFromResponse(response, modelId);

activity?.SetCompletionResponse(textContents);
-    this.LogTextGenerationUsage(executionSettings);
+    this.LogTextGenerationUsage(huggingFaceExecutionSettings);

return textContents;
}
@@ -166,10 +168,12 @@ public async IAsyncEnumerable<StreamingTextContent> StreamGenerateTextAsync(
{
string modelId = executionSettings?.ModelId ?? this.ModelId;
var endpoint = this.GetTextGenerationEndpoint(modelId);
-    var request = this.CreateTextRequest(prompt, executionSettings);

+    var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
+    var request = this.CreateTextRequest(prompt, huggingFaceExecutionSettings);
    request.Stream = true;

-    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, executionSettings);
+    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this.ModelProvider, prompt, huggingFaceExecutionSettings);
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
try
@@ -239,9 +243,8 @@ private static StreamingTextContent GetStreamingTextContentFromStreamResponse(Te

private TextGenerationRequest CreateTextRequest(
string prompt,
-    PromptExecutionSettings? promptExecutionSettings)
+    HuggingFacePromptExecutionSettings huggingFaceExecutionSettings)
{
-    var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(promptExecutionSettings);
ValidateMaxNewTokens(huggingFaceExecutionSettings.MaxNewTokens);
var request = TextGenerationRequest.FromPromptAndExecutionSettings(prompt, huggingFaceExecutionSettings);
return request;
@@ -253,13 +256,13 @@ private static List<TextContent> GetTextContentsFromResponse(TextGenerationRespo
private static List<TextContent> GetTextContentsFromResponse(ImageToTextGenerationResponse response, string modelId)
=> response.Select(r => new TextContent(r.GeneratedText, modelId, r, Encoding.UTF8)).ToList();

-    private void LogTextGenerationUsage(PromptExecutionSettings? executionSettings)
+    private void LogTextGenerationUsage(HuggingFacePromptExecutionSettings executionSettings)
{
if (this.Logger.IsEnabled(LogLevel.Debug))
{
-            this.Logger?.LogDebug(
+            this.Logger.LogDebug(
                "HuggingFace text generation usage: ModelId: {ModelId}",
-                executionSettings?.ModelId ?? this.ModelId);
+                executionSettings.ModelId ?? this.ModelId);
}
}
private Uri GetTextGenerationEndpoint(string modelId)
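Hoisting the conversion out of CreateTextRequest means it runs once per call, and the same concrete instance now feeds the request body, the diagnostics activity, and the usage log. FromExecutionSettings is the connector's real helper; the surrounding snippet is an illustrative sketch, assuming the Microsoft.SemanticKernel.Connectors.HuggingFace package:

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.HuggingFace;

// FromExecutionSettings accepts the abstract settings, null included, and
// returns the concrete HuggingFace type with defaults filled in.
PromptExecutionSettings? incoming = null;
HuggingFacePromptExecutionSettings hfSettings =
    HuggingFacePromptExecutionSettings.FromExecutionSettings(incoming);

// Because even a null input yields a usable instance, CreateTextRequest and
// LogTextGenerationUsage can take the non-nullable concrete type and drop
// their ?. null-conditionals.
```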
@@ -82,10 +82,14 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> StreamCompleteChatM
{
string modelId = executionSettings?.ModelId ?? this._clientCore.ModelId;
var endpoint = this.GetChatGenerationEndpoint();
-    var request = this.CreateChatRequest(chatHistory, executionSettings);

+    var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
+    huggingFaceExecutionSettings.ModelId ??= this._clientCore.ModelId;
+
+    var request = this.CreateChatRequest(chatHistory, huggingFaceExecutionSettings);
    request.Stream = true;

-    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, executionSettings);
+    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, huggingFaceExecutionSettings);
HttpResponseMessage? httpResponseMessage = null;
Stream? responseStream = null;
try
@@ -142,9 +146,12 @@ internal async Task<IReadOnlyList<ChatMessageContent>> CompleteChatMessageAsync(
{
string modelId = executionSettings?.ModelId ?? this._clientCore.ModelId;
var endpoint = this.GetChatGenerationEndpoint();
-    var request = this.CreateChatRequest(chatHistory, executionSettings);

-    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, executionSettings);
+    var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(executionSettings);
+    huggingFaceExecutionSettings.ModelId ??= this._clientCore.ModelId;
+    var request = this.CreateChatRequest(chatHistory, huggingFaceExecutionSettings);

+    using var activity = ModelDiagnostics.StartCompletionActivity(endpoint, modelId, this._clientCore.ModelProvider, chatHistory, huggingFaceExecutionSettings);
using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey);

ChatCompletionResponse response;
Expand All @@ -164,12 +171,12 @@ internal async Task<IReadOnlyList<ChatMessageContent>> CompleteChatMessageAsync(
var chatContents = GetChatMessageContentsFromResponse(response, modelId);

activity?.SetCompletionResponse(chatContents, response.Usage?.PromptTokens, response.Usage?.CompletionTokens);
-    this.LogChatCompletionUsage(executionSettings, response);
+    this.LogChatCompletionUsage(huggingFaceExecutionSettings, response);

return chatContents;
}

-    private void LogChatCompletionUsage(PromptExecutionSettings? executionSettings, ChatCompletionResponse chatCompletionResponse)
+    private void LogChatCompletionUsage(HuggingFacePromptExecutionSettings executionSettings, ChatCompletionResponse chatCompletionResponse)
{
if (this._clientCore.Logger.IsEnabled(LogLevel.Debug))
{
@@ -263,11 +270,8 @@ private async IAsyncEnumerable<StreamingChatMessageContent> ProcessChatResponseS

private ChatCompletionRequest CreateChatRequest(
ChatHistory chatHistory,
-    PromptExecutionSettings? promptExecutionSettings)
+    HuggingFacePromptExecutionSettings huggingFaceExecutionSettings)
{
-    var huggingFaceExecutionSettings = HuggingFacePromptExecutionSettings.FromExecutionSettings(promptExecutionSettings);
-    huggingFaceExecutionSettings.ModelId ??= this._clientCore.ModelId;
-
HuggingFaceClient.ValidateMaxTokens(huggingFaceExecutionSettings.MaxTokens);
var request = ChatCompletionRequest.FromChatHistoryAndExecutionSettings(chatHistory, huggingFaceExecutionSettings);
return request;
@@ -138,7 +138,7 @@ internal async Task<IReadOnlyList<TextContent>> GetTextResultsAsync(

Completions? responseData = null;
List<TextContent> responseContent;
-    using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, executionSettings))
+    using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, textExecutionSettings))
{
try
{
@@ -183,7 +183,7 @@ internal async IAsyncEnumerable<StreamingTextContent> GetStreamingTextContentsAs

var options = CreateCompletionsOptions(prompt, textExecutionSettings, this.DeploymentOrModelName);

-    using var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, executionSettings);
+    using var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, prompt, textExecutionSettings);

StreamingResponse<Completions> response;
try
@@ -391,7 +391,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// Make the request.
ChatCompletions? responseData = null;
List<OpenAIChatMessageContent> responseContent;
-    using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, executionSettings))
+    using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, chatExecutionSettings))
{
try
{
@@ -663,7 +663,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
ChatRole? streamedRole = default;
CompletionsFinishReason finishReason = default;

-    using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, executionSettings))
+    using (var activity = ModelDiagnostics.StartCompletionActivity(this.Endpoint, this.DeploymentOrModelName, ModelProvider, chat, chatExecutionSettings))
{
// Make the request.
StreamingResponse<StreamingChatCompletionsUpdate> response;
@@ -39,14 +39,26 @@ internal static class ModelDiagnostics
    /// Start a text completion activity for a given model.
    /// The activity will be tagged with a set of attributes specified by the semantic conventions.
    /// </summary>
-    public static Activity? StartCompletionActivity(Uri? endpoint, string modelName, string modelProvider, string prompt, PromptExecutionSettings? executionSettings)
+    public static Activity? StartCompletionActivity<TPromptExecutionSettings>(
+        Uri? endpoint,
+        string modelName,
+        string modelProvider,
+        string prompt,
+        TPromptExecutionSettings? executionSettings
+    ) where TPromptExecutionSettings : PromptExecutionSettings
        => StartCompletionActivity(endpoint, modelName, modelProvider, prompt, executionSettings, prompt => prompt);

    /// <summary>
    /// Start a chat completion activity for a given model.
    /// The activity will be tagged with a set of attributes specified by the semantic conventions.
    /// </summary>
-    public static Activity? StartCompletionActivity(Uri? endpoint, string modelName, string modelProvider, ChatHistory chatHistory, PromptExecutionSettings? executionSettings)
+    public static Activity? StartCompletionActivity<TPromptExecutionSettings>(
+        Uri? endpoint,
+        string modelName,
+        string modelProvider,
+        ChatHistory chatHistory,
+        TPromptExecutionSettings? executionSettings
+    ) where TPromptExecutionSettings : PromptExecutionSettings
        => StartCompletionActivity(endpoint, modelName, modelProvider, chatHistory, executionSettings, ToOpenAIFormat);
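
Making these overloads generic, rather than keeping the plain PromptExecutionSettings? parameter, matters for the serialization step further down: JsonSerializer.Serialize<T> writes only the members of the static type T unless polymorphism is configured, so a base-typed parameter would silently drop the derived settings' properties. Call sites are unaffected because the type parameter is inferred; a hedged sketch from inside the connector assembly (endpoint and model names are made up):

```csharp
// TPromptExecutionSettings is inferred as HuggingFacePromptExecutionSettings,
// so the round-trip in AddOptionalTags sees the concrete type's properties.
using var activity = ModelDiagnostics.StartCompletionActivity(
    new Uri("https://example.invalid/models/some-model"), // hypothetical endpoint
    "some-model",
    "huggingface",
    "Why is the sky blue?",
    new HuggingFacePromptExecutionSettings { Temperature = 0 });
```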

/// <summary>
@@ -109,16 +121,24 @@ public static bool IsModelDiagnosticsEnabled()
}

#region Private
-    private static void AddOptionalTags(Activity? activity, PromptExecutionSettings? executionSettings)
+    private static void AddOptionalTags<TPromptExecutionSettings>(Activity? activity, TPromptExecutionSettings? executionSettings)
+        where TPromptExecutionSettings : PromptExecutionSettings
{
-        if (activity is null || executionSettings?.ExtensionData is null)
+        if (activity is null || executionSettings is null)
{
return;
}

+        // Serialize and deserialize the execution settings to get the extension data
+        var deserializedSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(JsonSerializer.Serialize(executionSettings));
+        if (deserializedSettings is null || deserializedSettings.ExtensionData is null)
+        {
+            return;
+        }
+
void TryAddTag(string key, string tag)
{
-            if (executionSettings.ExtensionData.TryGetValue(key, out var value))
+            if (deserializedSettings.ExtensionData.TryGetValue(key, out var value))
{
activity.SetTag(tag, value);
}
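
This round-trip is the heart of the fix. A derived settings type holds, say, Temperature as a typed CLR property, so reading the base class's ExtensionData directly yields null; serializing the concrete type and deserializing as the base type funnels every typed member into the [JsonExtensionData] bucket, where TryAddTag can find it. A self-contained sketch of the technique using stand-in types (not the actual SK classes):

```csharp
using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Text.Json.Serialization;

public class BaseSettings
{
    // JSON members with no matching property land here on deserialization.
    [JsonExtensionData]
    public IDictionary<string, object>? ExtensionData { get; set; }
}

public sealed class VendorSettings : BaseSettings
{
    [JsonPropertyName("temperature")]
    public double? Temperature { get; set; }
}

public static class Demo
{
    public static void Main()
    {
        var settings = new VendorSettings { Temperature = 0 };

        // Before the fix: the typed property never touches ExtensionData.
        Console.WriteLine(settings.ExtensionData is null); // True

        // The fix: serialize the concrete type, deserialize as the base type.
        var roundTripped = JsonSerializer.Deserialize<BaseSettings>(
            JsonSerializer.Serialize(settings));
        Console.WriteLine(roundTripped?.ExtensionData?["temperature"]); // 0
    }
}
```

The cost is one serialize plus one deserialize per completion call while an activity is live, which appears to be the trade the change accepts to stay agnostic of each connector's settings shape.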
Expand Down Expand Up @@ -194,13 +214,13 @@ private static void ToOpenAIFormat(StringBuilder sb, ChatMessageContentItemColle
/// Start a completion activity and return the activity.
/// The `formatPrompt` delegate won't be invoked if events are disabled.
/// </summary>
-    private static Activity? StartCompletionActivity<T>(
+    private static Activity? StartCompletionActivity<TPrompt, TPromptExecutionSettings>(
        Uri? endpoint,
        string modelName,
        string modelProvider,
-        T prompt,
-        PromptExecutionSettings? executionSettings,
-        Func<T, string> formatPrompt)
+        TPrompt prompt,
+        TPromptExecutionSettings? executionSettings,
+        Func<TPrompt, string> formatPrompt) where TPromptExecutionSettings : PromptExecutionSettings
{
if (!IsModelDiagnosticsEnabled())
{