Skip to content

.Net: Summarization and translation evaluation examples with Filters #6262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/_typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ EOF = "EOF" # End of File
ans = "ans" # Short for answers
arange = "arange" # Method in Python numpy package
prompty = "prompty" # prompty is a format name.
ist = "ist" # German language

[default.extend-identifiers]
ags = "ags" # Azure Graph Service
Expand Down
9 changes: 9 additions & 0 deletions dotnet/SK-dotnet.sln
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FunctionInvocationApproval"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "CodeInterpreterPlugin", "samples\Demos\CodeInterpreterPlugin\CodeInterpreterPlugin.csproj", "{3ED53702-0E53-473A-A0F4-645DB33541C2}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "QualityCheckWithFilters", "samples\Demos\QualityCheck\QualityCheckWithFilters\QualityCheckWithFilters.csproj", "{1D3EEB5B-0E06-4700-80D5-164956E43D0A}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TimePlugin", "samples\Demos\TimePlugin\TimePlugin.csproj", "{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}"
EndProject
Global
Expand Down Expand Up @@ -725,6 +727,12 @@ Global
{3ED53702-0E53-473A-A0F4-645DB33541C2}.Publish|Any CPU.Build.0 = Debug|Any CPU
{3ED53702-0E53-473A-A0F4-645DB33541C2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3ED53702-0E53-473A-A0F4-645DB33541C2}.Release|Any CPU.Build.0 = Release|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Publish|Any CPU.Build.0 = Debug|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1D3EEB5B-0E06-4700-80D5-164956E43D0A}.Release|Any CPU.Build.0 = Release|Any CPU
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
Expand Down Expand Up @@ -831,6 +839,7 @@ Global
{925B1185-8B58-4E2D-95C9-4CA0BA9364E5} = {FA3720F1-C99A-49B2-9577-A940257098BF}
{6B56D8EE-9991-43E3-90B2-B8F5C5CE77C2} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{3ED53702-0E53-473A-A0F4-645DB33541C2} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{1D3EEB5B-0E06-4700-80D5-164956E43D0A} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
{F312FCE1-12D7-4DEF-BC29-2FF6618509F3} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text summarization evaluation using BERTScore metric: https://huggingface.co/spaces/evaluate-metric/bertscore.
/// Evaluation result contains three values: precision, recall and F1 score.
/// The higher F1 score - the better the quality of the summary.
/// </summary>
internal sealed class BertSummarizationEvaluationFilter(
EvaluationService evaluationService,
ILogger logger,
double threshold) : IFunctionInvocationFilter
{
public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
{
await next(context);

var sourceText = context.Result.RenderedPrompt!;
var summary = context.Result.ToString();

var request = new SummarizationEvaluationRequest { Sources = [sourceText], Summaries = [summary] };
var response = await evaluationService.EvaluateAsync<SummarizationEvaluationRequest, BertSummarizationEvaluationResponse>(request);

var precision = Math.Round(response.Precision[0], 4);
var recall = Math.Round(response.Recall[0], 4);
var f1 = Math.Round(response.F1[0], 4);

logger.LogInformation("[BERT] Precision: {Precision}, Recall: {Recall}, F1: {F1}", precision, recall, f1);

if (f1 < threshold)
{
throw new KernelException($"BERT summary evaluation score ({f1}) is lower than threshold ({threshold})");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text summarization evaluation using BLEU metric: https://huggingface.co/spaces/evaluate-metric/bleu.
/// Evaluation result contains values like score, precisions, brevity penalty and length ratio.
/// The closer the score and precision values are to 1 - the better the quality of the summary.
/// </summary>
internal sealed class BleuSummarizationEvaluationFilter(
EvaluationService evaluationService,
ILogger logger,
double threshold) : IFunctionInvocationFilter
{
public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
{
await next(context);

var sourceText = context.Result.RenderedPrompt!;
var summary = context.Result.ToString();

var request = new SummarizationEvaluationRequest { Sources = [sourceText], Summaries = [summary] };
var response = await evaluationService.EvaluateAsync<SummarizationEvaluationRequest, BleuSummarizationEvaluationResponse>(request);

var score = Math.Round(response.Score, 4);
var precisions = response.Precisions.Select(l => Math.Round(l, 4)).ToList();
var brevityPenalty = Math.Round(response.BrevityPenalty, 4);
var lengthRatio = Math.Round(response.LengthRatio, 4);

logger.LogInformation("[BLEU] Score: {Score}, Precisions: {Precisions}, Brevity penalty: {BrevityPenalty}, Length Ratio: {LengthRatio}",
score,
string.Join(", ", precisions),
brevityPenalty,
lengthRatio);

if (precisions[0] < threshold)
{
throw new KernelException($"BLEU summary evaluation score ({precisions[0]}) is lower than threshold ({threshold})");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text translation evaluation using COMET metric: https://huggingface.co/Unbabel/wmt22-cometkiwi-da.
/// COMET score ranges from 0 to 1, where higher values indicate better translation.
/// </summary>
internal sealed class CometTranslationEvaluationFilter(
EvaluationService evaluationService,
ILogger logger,
double threshold) : IFunctionInvocationFilter
{
public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
{
await next(context);

var sourceText = context.Result.RenderedPrompt!;
var translation = context.Result.ToString();

logger.LogInformation("Translation: {Translation}", translation);

var request = new TranslationEvaluationRequest { Sources = [sourceText], Translations = [translation] };
var response = await evaluationService.EvaluateAsync<TranslationEvaluationRequest, CometTranslationEvaluationResponse>(request);

var score = Math.Round(response.Scores[0], 4);

logger.LogInformation("[COMET] Score: {Score}", score);

if (score < threshold)
{
throw new KernelException($"COMET translation evaluation score ({score}) is lower than threshold ({threshold})");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Factory class for function invocation filters based on evaluation score type.
/// </summary>
internal sealed class FilterFactory
{
private static readonly Dictionary<EvaluationScoreType, Func<EvaluationService, ILogger, double, IFunctionInvocationFilter>> s_filters = new()
{
[EvaluationScoreType.BERT] = (service, logger, threshold) => new BertSummarizationEvaluationFilter(service, logger, threshold),
[EvaluationScoreType.BLEU] = (service, logger, threshold) => new BleuSummarizationEvaluationFilter(service, logger, threshold),
[EvaluationScoreType.METEOR] = (service, logger, threshold) => new MeteorSummarizationEvaluationFilter(service, logger, threshold),
[EvaluationScoreType.COMET] = (service, logger, threshold) => new CometTranslationEvaluationFilter(service, logger, threshold),
};

public static IFunctionInvocationFilter Create(EvaluationScoreType type, EvaluationService evaluationService, ILogger logger, double threshold)
=> s_filters[type].Invoke(evaluationService, logger, threshold);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using QualityCheckWithFilters.Models;
using QualityCheckWithFilters.Services;

namespace QualityCheckWithFilters.Filters;

/// <summary>
/// Filter which performs text summarization evaluation using METEOR metric: https://huggingface.co/spaces/evaluate-metric/meteor.
/// METEOR score ranges from 0 to 1, where higher values indicate better similarity between original text and generated summary.
/// </summary>
internal sealed class MeteorSummarizationEvaluationFilter(
EvaluationService evaluationService,
ILogger logger,
double threshold) : IFunctionInvocationFilter
{
public async Task OnFunctionInvocationAsync(FunctionInvocationContext context, Func<FunctionInvocationContext, Task> next)
{
await next(context);

var sourceText = context.Result.RenderedPrompt!;
var summary = context.Result.ToString();

var request = new SummarizationEvaluationRequest { Sources = [sourceText], Summaries = [summary] };
var response = await evaluationService.EvaluateAsync<SummarizationEvaluationRequest, MeteorSummarizationEvaluationResponse>(request);

var score = Math.Round(response.Score, 4);

logger.LogInformation("[METEOR] Score: {Score}", score);

if (score < threshold)
{
throw new KernelException($"METEOR summary evaluation score ({score}) is lower than threshold ({threshold})");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace QualityCheckWithFilters.Models;

/// <summary>Base request model with source texts.</summary>
internal class EvaluationRequest
{
[JsonPropertyName("sources")]
public List<string> Sources { get; set; }
}

/// <summary>Request model with generated summaries.</summary>
internal sealed class SummarizationEvaluationRequest : EvaluationRequest
{
[JsonPropertyName("summaries")]
public List<string> Summaries { get; set; }
}

/// <summary>Request model with generated translations.</summary>
internal sealed class TranslationEvaluationRequest : EvaluationRequest
{
[JsonPropertyName("translations")]
public List<string> Translations { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace QualityCheckWithFilters.Models;

/// <summary>Response model for BERTScore metric: https://huggingface.co/spaces/evaluate-metric/bertscore.</summary>
internal sealed class BertSummarizationEvaluationResponse
{
[JsonPropertyName("precision")]
public List<double> Precision { get; set; }

[JsonPropertyName("recall")]
public List<double> Recall { get; set; }

[JsonPropertyName("f1")]
public List<double> F1 { get; set; }
}

/// <summary>Response model for BLEU metric: https://huggingface.co/spaces/evaluate-metric/bleu.</summary>
internal sealed class BleuSummarizationEvaluationResponse
{
[JsonPropertyName("bleu")]
public double Score { get; set; }

[JsonPropertyName("precisions")]
public List<double> Precisions { get; set; }

[JsonPropertyName("brevity_penalty")]
public double BrevityPenalty { get; set; }

[JsonPropertyName("length_ratio")]
public double LengthRatio { get; set; }
}

/// <summary>Response model for METEOR metric: https://huggingface.co/spaces/evaluate-metric/meteor.</summary>
internal sealed class MeteorSummarizationEvaluationResponse
{
[JsonPropertyName("meteor")]
public double Score { get; set; }
}

/// <summary>Response model for COMET metric: https://huggingface.co/Unbabel/wmt22-cometkiwi-da.</summary>
internal sealed class CometTranslationEvaluationResponse
{
[JsonPropertyName("scores")]
public List<double> Scores { get; set; }

[JsonPropertyName("system_score")]
public double SystemScore { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;

namespace QualityCheckWithFilters.Models;

/// <summary>
/// Internal representation of evaluation score type to configure and run examples.
/// </summary>
internal readonly struct EvaluationScoreType(string endpoint) : IEquatable<EvaluationScoreType>
{
public string Endpoint { get; } = endpoint;

public static EvaluationScoreType BERT = new("bert-score");
public static EvaluationScoreType BLEU = new("bleu-score");
public static EvaluationScoreType METEOR = new("meteor-score");
public static EvaluationScoreType COMET = new("comet-score");

public static bool operator ==(EvaluationScoreType left, EvaluationScoreType right) => left.Equals(right);
public static bool operator !=(EvaluationScoreType left, EvaluationScoreType right) => !(left == right);

/// <inheritdoc/>
public override bool Equals([NotNullWhen(true)] object? obj) => obj is EvaluationScoreType other && this == other;

/// <inheritdoc/>
public bool Equals(EvaluationScoreType other) => string.Equals(this.Endpoint, other.Endpoint, StringComparison.OrdinalIgnoreCase);

/// <inheritdoc/>
public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(this.Endpoint ?? string.Empty);

/// <inheritdoc/>
public override string ToString() => this.Endpoint ?? string.Empty;
}
Loading
Loading