Skip to content

.Net: SK integration with MEAI Abstractions (Service Selector + Contents) Phase 1 #10651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2411de2
Service Selector investigation
RogerBarreto Feb 20, 2025
e980d1b
MEAI integration chatclient kernel service selector
RogerBarreto Feb 20, 2025
4e82956
Update UT
RogerBarreto Feb 20, 2025
fdcf588
WIP
RogerBarreto Feb 20, 2025
8911ca6
UT for IChatClient service selector WIP
RogerBarreto Feb 24, 2025
ff00e14
Undo change
RogerBarreto Feb 24, 2025
0232073
Fix warnings
RogerBarreto Feb 24, 2025
79341db
UT passing OrderedServiceSelector
RogerBarreto Feb 26, 2025
d29ca14
Ensure kernel.Invoke can return MEAI.Contents
RogerBarreto Feb 26, 2025
e73fad7
Adding bidirectional content support
RogerBarreto Feb 26, 2025
26e3cbb
Content Permutations UT
RogerBarreto Feb 26, 2025
19dab90
Ensure types converge
RogerBarreto Feb 26, 2025
92d7b49
Fix warnings
RogerBarreto Feb 26, 2025
1bd07fd
Reducing public apis
RogerBarreto Feb 26, 2025
d82f7ea
Add suppressions
RogerBarreto Feb 26, 2025
66d3f0a
Address PR comments
RogerBarreto Feb 27, 2025
4071341
Addressed Selector comments
RogerBarreto Feb 27, 2025
e2fdba6
Added integration tests
RogerBarreto Feb 27, 2025
02ed447
GptAIServiceSelector
RogerBarreto Feb 27, 2025
1f6bd8b
Add experimental + fixes
RogerBarreto Feb 27, 2025
2d4b08f
Fix UT Agents
RogerBarreto Feb 27, 2025
7a50f09
Fix usings
RogerBarreto Feb 27, 2025
f26bffb
Address PR Feedback
RogerBarreto Feb 27, 2025
bb35f00
Fix warnings
RogerBarreto Feb 28, 2025
38724da
Fix warnings
RogerBarreto Feb 28, 2025
4bc6b63
Fix warnings
RogerBarreto Feb 28, 2025
b188434
Adding missing UT, addressing Review feedback
RogerBarreto Feb 28, 2025
178be40
Add missing UT
RogerBarreto Feb 28, 2025
d2588bb
Fix warnings
RogerBarreto Feb 28, 2025
b70d35e
Warning fix + UT
RogerBarreto Mar 3, 2025
c7fb8cd
Fix GptSelector
RogerBarreto Mar 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
Expand All @@ -25,12 +26,16 @@ public async Task RunAsync()
endpoint: TestConfiguration.AzureOpenAI.Endpoint,
apiKey: TestConfiguration.AzureOpenAI.ApiKey,
serviceId: "AzureOpenAIChat",
modelId: TestConfiguration.AzureOpenAI.ChatModelId)
modelId: "o1-mini")
.AddOpenAIChatCompletion(
modelId: TestConfiguration.OpenAI.ChatModelId,
modelId: "o1-mini",
apiKey: TestConfiguration.OpenAI.ApiKey,
serviceId: "OpenAIChat");
builder.Services.AddSingleton<IAIServiceSelector>(new GptAIServiceSelector(this.Output)); // Use the custom AI service selector to select the GPT model
builder.Services
.AddSingleton<IAIServiceSelector>(new GptAIServiceSelector(this.Output)) // Use the custom AI service selector to select the GPT model
.AddKeyedChatClient("OpenAIChatClient", new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey)
.AsChatClient("gpt-4o")); // Add a IChatClient to the kernel

Kernel kernel = builder.Build();

// This invocation is done with the model selected by the custom selector
Expand All @@ -45,31 +50,59 @@ public async Task RunAsync()
/// a completion model whose name starts with "gpt". But this logic could
/// be as elaborate as needed to apply your own selection criteria.
/// </summary>
private sealed class GptAIServiceSelector(ITestOutputHelper output) : IAIServiceSelector
private sealed class GptAIServiceSelector(ITestOutputHelper output) : IAIServiceSelector, IServiceSelector
{
private readonly ITestOutputHelper _output = output;

public bool TrySelectAIService<T>(
public bool TrySelect<T>(
Kernel kernel, KernelFunction function, KernelArguments arguments,
[NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class, IAIService
[NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class
{
foreach (var serviceToCheck in kernel.GetAllServices<T>())
{
// Find the first service that has a model id that starts with "gpt"
var serviceModelId = serviceToCheck.GetModelId();
var endpoint = serviceToCheck.GetEndpoint();
if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith("gpt", StringComparison.OrdinalIgnoreCase))
if (serviceToCheck is IAIService aiService)
{
this._output.WriteLine($"Selected model: {serviceModelId} {endpoint}");
service = serviceToCheck;
serviceSettings = new OpenAIPromptExecutionSettings();
return true;
// Find the first service that has a model id that starts with "gpt"
var serviceModelId = aiService.GetModelId();
var endpoint = aiService.GetEndpoint();

if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith("gpt", StringComparison.OrdinalIgnoreCase))
{
this._output.WriteLine($"Selected model: {serviceModelId} {endpoint}");
service = serviceToCheck;
serviceSettings = new OpenAIPromptExecutionSettings();
return true;
}
}
else if (serviceToCheck is IChatClient chatClient)
{
var metadata = chatClient.GetService<ChatClientMetadata>();

// Find the first service that has a model id that starts with "gpt"
var serviceModelId = metadata?.ModelId;
var endpoint = metadata?.ProviderUri;

if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith("gpt", StringComparison.OrdinalIgnoreCase))
{
this._output.WriteLine($"Selected model: {serviceModelId} {endpoint}");
service = serviceToCheck;
serviceSettings = new OpenAIPromptExecutionSettings();
return true;
}
}
}

service = null;
serviceSettings = null;
return false;
}

public bool TrySelectAIService<T>(
Kernel kernel,
KernelFunction function,
KernelArguments arguments,
[NotNullWhen(true)] out T? service,
out PromptExecutionSettings? serviceSettings) where T : class, IAIService
=> this.TrySelect(kernel, function, arguments, out service, out serviceSettings);
}
}
1 change: 0 additions & 1 deletion dotnet/samples/Demos/HomeAutomation/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Example that demonstrates how to use Semantic Kernel in conjunction with depende
using Microsoft.SemanticKernel.ChatCompletion;
// For Azure OpenAI configuration
#pragma warning disable IDE0005 // Using directive is unnecessary.
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.Connectors.OpenAI;

namespace HomeAutomation;
Expand Down
44 changes: 37 additions & 7 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
Expand All @@ -9,7 +10,7 @@
using Microsoft.SemanticKernel.Agents.Extensions;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Diagnostics;
using Microsoft.SemanticKernel.Services;
using MEAI = Microsoft.Extensions.AI;

namespace Microsoft.SemanticKernel.Agents;

Expand Down Expand Up @@ -101,13 +102,42 @@ internal static (IChatCompletionService service, PromptExecutionSettings? execut
{
// Need to provide a KernelFunction to the service selector as a container for the execution-settings.
KernelFunction nullPrompt = KernelFunctionFactory.CreateFromPrompt("placeholder", arguments?.ExecutionSettings?.Values);
(IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) =
kernel.ServiceSelector.SelectAIService<IChatCompletionService>(
kernel,
nullPrompt,
arguments ?? []);

return (chatCompletionService, executionSettings);
kernel.ServiceSelector.TrySelectAIService<IChatCompletionService>(kernel, nullPrompt, arguments ?? [], out IChatCompletionService? chatCompletionService, out PromptExecutionSettings? executionSettings);

#pragma warning disable CA2000 // Dispose objects before losing scope
if (chatCompletionService is null
&& kernel.ServiceSelector is IServiceSelector chatClientSelector
&& chatClientSelector.TrySelect<MEAI.IChatClient>(kernel, nullPrompt, arguments ?? [], out var chatClient, out executionSettings)
&& chatClient is not null)
{
// This change is temporary until Agents support IChatClient natively in near future.
chatCompletionService = chatClient!.AsChatCompletionService();
}
#pragma warning restore CA2000 // Dispose objects before losing scope

if (chatCompletionService is null)
{
var message = new StringBuilder().Append("No service was found for any of the supported types: ").Append(typeof(IChatCompletionService)).Append(", ").Append(typeof(MEAI.IChatClient)).Append('.');
if (nullPrompt.ExecutionSettings is not null)
{
string serviceIds = string.Join("|", nullPrompt.ExecutionSettings.Keys);
if (!string.IsNullOrEmpty(serviceIds))
{
message.Append(" Expected serviceIds: ").Append(serviceIds).Append('.');
}

string modelIds = string.Join("|", nullPrompt.ExecutionSettings.Values.Select(model => model.ModelId));
if (!string.IsNullOrEmpty(modelIds))
{
message.Append(" Expected modelIds: ").Append(modelIds).Append('.');
}
}

throw new KernelException(message.ToString());
}

return (chatCompletionService!, executionSettings);
}

#region private
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected()
// Set up a MeterListener to capture the measurements
using MeterListener listener = EnableTelemetryMeters();

var measurements = new Dictionary<string, List<int>>
var measurements = new Dictionary<string, List<long>>
{
["semantic_kernel.function.invocation.token_usage.prompt"] = [],
["semantic_kernel.function.invocation.token_usage.completion"] = [],
};

listener.SetMeasurementEventCallback<int>((instrument, measurement, tags, state) =>
listener.SetMeasurementEventCallback<long>((instrument, measurement, tags, state) =>
{
if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" ||
instrument.Name == "semantic_kernel.function.invocation.token_usage.completion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected()
// Set up a MeterListener to capture the measurements
using MeterListener listener = EnableTelemetryMeters();

var measurements = new Dictionary<string, List<int>>
var measurements = new Dictionary<string, List<long>>
{
["semantic_kernel.function.invocation.token_usage.prompt"] = [],
["semantic_kernel.function.invocation.token_usage.completion"] = [],
};

listener.SetMeasurementEventCallback<int>((instrument, measurement, tags, state) =>
listener.SetMeasurementEventCallback<long>((instrument, measurement, tags, state) =>
{
if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" ||
instrument.Name == "semantic_kernel.function.invocation.token_usage.completion")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

#pragma warning disable IDE0005 // Using directive is unnecessary.
using System;
using System.Linq;
using System.Runtime.Serialization;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

#pragma warning disable IDE0005 // Using directive is unnecessary.
using System;
using System.Linq;
using System.Runtime.Serialization;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http.Resilience;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using OpenAI.Chat;
using OpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;
using MEAI = Microsoft.Extensions.AI;
using OAI = OpenAI.Chat;

namespace SemanticKernel.IntegrationTests.Connectors.OpenAI;

Expand All @@ -43,6 +46,40 @@ public async Task ItCanUseOpenAiChatForTextGenerationAsync()
Assert.Contains("Uranus", result.GetValue<string>(), StringComparison.InvariantCultureIgnoreCase);
}

[Fact]
public async Task ItCanUseOpenAiChatClientAndContentsAsync()
{
var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get<OpenAIConfiguration>();
Assert.NotNull(OpenAIConfiguration);
Assert.NotNull(OpenAIConfiguration.ChatModelId);
Assert.NotNull(OpenAIConfiguration.ApiKey);
Assert.NotNull(OpenAIConfiguration.ServiceId);

// Arrange
var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey);
var builder = Kernel.CreateBuilder();
builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId));
var kernel = builder.Build();

var func = kernel.CreateFunctionFromPrompt(
"List the two planets after '{{$input}}', excluding moons, using bullet points.",
new OpenAIPromptExecutionSettings());

// Act
var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" });

// Assert
Assert.NotNull(result);
Assert.Contains("Saturn", result.GetValue<string>(), StringComparison.InvariantCultureIgnoreCase);
Assert.Contains("Uranus", result.GetValue<string>(), StringComparison.InvariantCultureIgnoreCase);
var chatResponse = Assert.IsType<ChatResponse>(result.GetValue<ChatResponse>());
Assert.Contains("Saturn", chatResponse.Message.Text, StringComparison.InvariantCultureIgnoreCase);
var chatMessage = Assert.IsType<MEAI.ChatMessage>(result.GetValue<MEAI.ChatMessage>());
Assert.Contains("Uranus", chatMessage.Text, StringComparison.InvariantCultureIgnoreCase);
var chatMessageContent = Assert.IsType<ChatMessageContent>(result.GetValue<ChatMessageContent>());
Assert.Contains("Uranus", chatMessageContent.Content, StringComparison.InvariantCultureIgnoreCase);
}

[Fact]
public async Task OpenAIStreamingTestAsync()
{
Expand All @@ -65,6 +102,43 @@ public async Task OpenAIStreamingTestAsync()
Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase);
}

[Fact]
public async Task ItCanUseOpenAiStreamingChatClientAndContentsAsync()
{
var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get<OpenAIConfiguration>();
Assert.NotNull(OpenAIConfiguration);
Assert.NotNull(OpenAIConfiguration.ChatModelId);
Assert.NotNull(OpenAIConfiguration.ApiKey);
Assert.NotNull(OpenAIConfiguration.ServiceId);

// Arrange
var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey);
var builder = Kernel.CreateBuilder();
builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId));
var kernel = builder.Build();

var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin");

StringBuilder fullResultSK = new();
StringBuilder fullResultMEAI = new();

var prompt = "Where is the most famous fish market in Seattle, Washington, USA?";

// Act
await foreach (var content in kernel.InvokeStreamingAsync<StreamingKernelContent>(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }))
{
fullResultSK.Append(content);
}
await foreach (var content in kernel.InvokeStreamingAsync<ChatResponseUpdate>(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }))
{
fullResultMEAI.Append(content);
}

// Assert
Assert.Contains("Pike Place", fullResultSK.ToString(), StringComparison.OrdinalIgnoreCase);
Assert.Contains("Pike Place", fullResultMEAI.ToString(), StringComparison.OrdinalIgnoreCase);
}

[Fact]
public async Task OpenAIHttpRetryPolicyTestAsync()
{
Expand Down Expand Up @@ -208,7 +282,7 @@ public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int?
// Act
var result = await kernel.InvokePromptAsync("Hi, can you help me today?", new(settings));

var logProbabilityInfo = result.Metadata?["ContentTokenLogProbabilities"] as IReadOnlyList<ChatTokenLogProbabilityDetails>;
var logProbabilityInfo = result.Metadata?["ContentTokenLogProbabilities"] as IReadOnlyList<OAI.ChatTokenLogProbabilityDetails>;

// Assert
Assert.NotNull(logProbabilityInfo);
Expand Down
1 change: 1 addition & 0 deletions dotnet/src/IntegrationTests/IntegrationTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
<None Remove="skills\FunSkill\Joke\skprompt.txt" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" />
<PackageReference Include="Microsoft.Extensions.Configuration" />
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" />
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" />
Expand Down
Loading
Loading