diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index b5cfce829772..fc2cb101c126 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -8,7 +8,7 @@ false true - $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724 + $(NoWarn);CS8618,IDE0009,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724,IDE1006 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs index b0fdcad2e86f..d4631323c24d 100644 --- a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs +++ b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; @@ -8,29 +9,38 @@ namespace KernelExamples; +/// +/// This sample shows how to use a custom AI service selector to select a specific model by matching its id. +/// public class CustomAIServiceSelector(ITestOutputHelper output) : BaseTest(output) { - /// - /// Show how to use a custom AI service selector to select a specific model - /// [Fact] - public async Task RunAsync() + public async Task UsingCustomSelectToSelectServiceByMatchingModelId() { - Console.WriteLine($"======== {nameof(UsingCustomSelectToSelectServiceByMatchingModelId)} ========"); + Console.WriteLine($"======== {nameof(UsingCustomSelectToSelectServiceByMatchingModelId)} ========"); - // Build a kernel with multiple chat completion services + // Use the custom AI service selector to select any registered service whose model id starts with "gpt" + var customSelector = new GptAIServiceSelector(modelNameStartsWith: "gpt", this.Output); + + // Build a kernel with multiple chat services var builder = Kernel.CreateBuilder() .AddAzureOpenAIChatCompletion( deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName, endpoint: TestConfiguration.AzureOpenAI.Endpoint, apiKey: TestConfiguration.AzureOpenAI.ApiKey, serviceId: "AzureOpenAIChat", - modelId: TestConfiguration.AzureOpenAI.ChatModelId) + modelId: "o1-mini") .AddOpenAIChatCompletion( - modelId: TestConfiguration.OpenAI.ChatModelId, + modelId: "o1-mini", apiKey: TestConfiguration.OpenAI.ApiKey, serviceId: "OpenAIChat"); - builder.Services.AddSingleton(new GptAIServiceSelector(this.Output)); // Use the custom AI service selector to select the GPT model + + // The kernel also allows you to use an IChatClient chat service + builder.Services + .AddSingleton(customSelector) + .AddKeyedChatClient("OpenAIChatClient", new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .AsChatClient("gpt-4o")); // Add an IChatClient to the kernel + Kernel kernel = builder.Build(); // This invocation is done with the model selected by the custom selector @@ -45,20 +55,35 @@ public async Task RunAsync() /// a completion model whose name starts with "gpt". But this logic could /// be as elaborate as needed to apply your own selection criteria.
/// - private sealed class GptAIServiceSelector(ITestOutputHelper output) : IAIServiceSelector + private sealed class GptAIServiceSelector(string modelNameStartsWith, ITestOutputHelper output) : IAIServiceSelector, IChatClientSelector { private readonly ITestOutputHelper _output = output; + private readonly string _modelNameStartsWith = modelNameStartsWith; - public bool TrySelectAIService( + /// + private bool TrySelect( Kernel kernel, KernelFunction function, KernelArguments arguments, - [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class, IAIService + [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class { foreach (var serviceToCheck in kernel.GetAllServices()) { + string? serviceModelId = null; + string? endpoint = null; + + if (serviceToCheck is IAIService aiService) + { + serviceModelId = aiService.GetModelId(); + endpoint = aiService.GetEndpoint(); + } + else if (serviceToCheck is IChatClient chatClient) + { + var metadata = chatClient.GetService(); + serviceModelId = metadata?.ModelId; + endpoint = metadata?.ProviderUri?.ToString(); + } + // Find the first service that has a model id that starts with "gpt" - var serviceModelId = serviceToCheck.GetModelId(); - var endpoint = serviceToCheck.GetEndpoint(); - if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith("gpt", StringComparison.OrdinalIgnoreCase)) + if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith(this._modelNameStartsWith, StringComparison.OrdinalIgnoreCase)) { this._output.WriteLine($"Selected model: {serviceModelId} {endpoint}"); service = serviceToCheck; @@ -71,5 +96,23 @@ public bool TrySelectAIService( serviceSettings = null; return false; } + + /// + public bool TrySelectAIService( + Kernel kernel, + KernelFunction function, + KernelArguments arguments, + [NotNullWhen(true)] out T? service, + out PromptExecutionSettings? serviceSettings) where T : class, IAIService + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); + + /// + public bool TrySelectChatClient( + Kernel kernel, + KernelFunction function, + KernelArguments arguments, + [NotNullWhen(true)] out T? service, + out PromptExecutionSettings? serviceSettings) where T : class, IChatClient + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); } } diff --git a/dotnet/samples/Demos/HomeAutomation/Program.cs b/dotnet/samples/Demos/HomeAutomation/Program.cs index 3b8d1f009c2f..5b6f61cb5c2d 100644 --- a/dotnet/samples/Demos/HomeAutomation/Program.cs +++ b/dotnet/samples/Demos/HomeAutomation/Program.cs @@ -21,7 +21,6 @@ Example that demonstrates how to use Semantic Kernel in conjunction with depende using Microsoft.SemanticKernel.ChatCompletion; // For Azure OpenAI configuration #pragma warning disable IDE0005 // Using directive is unnecessary. -using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.OpenAI; namespace HomeAutomation; diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 015b0a22b0f1..59eb482ff976 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. 
using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; @@ -9,7 +10,6 @@ using Microsoft.SemanticKernel.Agents.Extensions; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Diagnostics; -using Microsoft.SemanticKernel.Services; namespace Microsoft.SemanticKernel.Agents; @@ -101,13 +101,42 @@ internal static (IChatCompletionService service, PromptExecutionSettings? execut { // Need to provide a KernelFunction to the service selector as a container for the execution-settings. KernelFunction nullPrompt = KernelFunctionFactory.CreateFromPrompt("placeholder", arguments?.ExecutionSettings?.Values); - (IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) = - kernel.ServiceSelector.SelectAIService( - kernel, - nullPrompt, - arguments ?? []); - return (chatCompletionService, executionSettings); + kernel.ServiceSelector.TrySelectAIService(kernel, nullPrompt, arguments ?? [], out IChatCompletionService? chatCompletionService, out PromptExecutionSettings? executionSettings); + +#pragma warning disable CA2000 // Dispose objects before losing scope + if (chatCompletionService is null + && kernel.ServiceSelector is IChatClientSelector chatClientSelector + && chatClientSelector.TrySelectChatClient(kernel, nullPrompt, arguments ?? [], out var chatClient, out executionSettings) + && chatClient is not null) + { + // This change is temporary until Agents support IChatClient natively in the near future. + chatCompletionService = chatClient!.AsChatCompletionService(); + } +#pragma warning restore CA2000 // Dispose objects before losing scope + + if (chatCompletionService is null) + { + var message = new StringBuilder().Append("No service was found for any of the supported types: ").Append(typeof(IChatCompletionService)).Append(", ").Append(typeof(Microsoft.Extensions.AI.IChatClient)).Append('.'); + if (nullPrompt.ExecutionSettings is not null) + { + string serviceIds = string.Join("|", nullPrompt.ExecutionSettings.Keys); + if (!string.IsNullOrEmpty(serviceIds)) + { + message.Append(" Expected serviceIds: ").Append(serviceIds).Append('.'); + } + + string modelIds = string.Join("|", nullPrompt.ExecutionSettings.Values.Select(model => model.ModelId)); + if (!string.IsNullOrEmpty(modelIds)) + { + message.Append(" Expected modelIds: ").Append(modelIds).Append('.'); + } + } + + throw new KernelException(message.ToString()); + } + + return (chatCompletionService!, executionSettings); } #region private diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs index 61685bb1daec..cd1c2a549003 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs @@ -62,13 +62,13 @@ public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected() // Set up a MeterListener to capture the measurements using MeterListener listener = EnableTelemetryMeters(); - var measurements = new Dictionary> + var measurements = new Dictionary> { ["semantic_kernel.function.invocation.token_usage.prompt"] = [], ["semantic_kernel.function.invocation.token_usage.completion"] = [], }; - listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => + listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => { if (instrument.Name ==
"semantic_kernel.function.invocation.token_usage.prompt" || instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs index fdf17710b77c..fd97491adf25 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs @@ -63,13 +63,13 @@ public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected() // Set up a MeterListener to capture the measurements using MeterListener listener = EnableTelemetryMeters(); - var measurements = new Dictionary> + var measurements = new Dictionary> { ["semantic_kernel.function.invocation.token_usage.prompt"] = [], ["semantic_kernel.function.invocation.token_usage.completion"] = [], }; - listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => + listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => { if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" || instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") diff --git a/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessCloudEventsTests.cs b/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessCloudEventsTests.cs index ee262b50f7e9..0433b88f367b 100644 --- a/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessCloudEventsTests.cs +++ b/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessCloudEventsTests.cs @@ -1,9 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. #pragma warning disable IDE0005 // Using directive is unnecessary. -using System; -using System.Linq; -using System.Runtime.Serialization; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel; diff --git a/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs b/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs index d5d2ca19934e..5964ed1a1773 100644 --- a/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs +++ b/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs @@ -1,9 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. #pragma warning disable IDE0005 // Using directive is unnecessary. 
-using System; using System.Linq; -using System.Runtime.Serialization; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel; diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs index 1359b701e29c..1e1b58133c83 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs @@ -8,12 +8,14 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Http.Resilience; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; +using OpenAI; using OpenAI.Chat; using SemanticKernel.IntegrationTests.TestSettings; using Xunit; @@ -43,6 +45,40 @@ public async Task ItCanUseOpenAiChatForTextGenerationAsync() Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); } + [Fact] + public async Task ItCanUseOpenAiChatClientAndContentsAsync() + { + var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(OpenAIConfiguration); + Assert.NotNull(OpenAIConfiguration.ChatModelId); + Assert.NotNull(OpenAIConfiguration.ApiKey); + Assert.NotNull(OpenAIConfiguration.ServiceId); + + // Arrange + var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey); + var builder = Kernel.CreateBuilder(); + builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId)); + var kernel = builder.Build(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using bullet points.", + new OpenAIPromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + var chatResponse = Assert.IsType(result.GetValue()); + Assert.Contains("Saturn", chatResponse.Message.Text, StringComparison.InvariantCultureIgnoreCase); + var chatMessage = Assert.IsType(result.GetValue()); + Assert.Contains("Uranus", chatMessage.Text, StringComparison.InvariantCultureIgnoreCase); + var chatMessageContent = Assert.IsType(result.GetValue()); + Assert.Contains("Uranus", chatMessageContent.Content, StringComparison.InvariantCultureIgnoreCase); + } + [Fact] public async Task OpenAIStreamingTestAsync() { @@ -65,6 +101,43 @@ public async Task OpenAIStreamingTestAsync() Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); } + [Fact] + public async Task ItCanUseOpenAiStreamingChatClientAndContentsAsync() + { + var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(OpenAIConfiguration); + Assert.NotNull(OpenAIConfiguration.ChatModelId); + Assert.NotNull(OpenAIConfiguration.ApiKey); + Assert.NotNull(OpenAIConfiguration.ServiceId); + + // Arrange + var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey); + var builder = Kernel.CreateBuilder(); + builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId)); + var kernel = builder.Build(); 
+ + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResultSK = new(); + StringBuilder fullResultMEAI = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResultSK.Append(content); + } + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResultMEAI.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResultSK.ToString(), StringComparison.OrdinalIgnoreCase); + Assert.Contains("Pike Place", fullResultMEAI.ToString(), StringComparison.OrdinalIgnoreCase); + } + [Fact] public async Task OpenAIHttpRetryPolicyTestAsync() { diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index cd4f12741f96..a88dc3386d97 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -41,6 +41,7 @@ + diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs new file mode 100644 index 000000000000..af8217d5e1fa --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Services; + +namespace Microsoft.SemanticKernel.AI.ChatCompletion; + +/// +/// Allow to be used as an in a +/// +internal sealed class ChatClientAIService : IAIService, IChatClient +{ + private readonly IChatClient _chatClient; + + /// + /// Storage for AI service attributes. + /// + internal Dictionary _internalAttributes { get; } = []; + + /// + /// Initializes a new instance of the class. + /// + /// Target . + internal ChatClientAIService(IChatClient chatClient) + { + Verify.NotNull(chatClient); + this._chatClient = chatClient; + + var metadata = this._chatClient.GetService(); + Verify.NotNull(metadata); + + this._internalAttributes[nameof(metadata.ModelId)] = metadata.ModelId; + this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName; + this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri; + } + + /// + public IReadOnlyDictionary Attributes => this._internalAttributes; + + /// + public void Dispose() + { + } + + /// + public Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => this._chatClient.GetResponseAsync(chatMessages, options, cancellationToken); + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + => this._chatClient.GetService(serviceType, serviceKey); + + /// + public IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) + => this._chatClient.GetStreamingResponseAsync(chatMessages, options, cancellationToken); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs new file mode 100644 index 000000000000..92bf6b9db105 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.ChatCompletion; + +/// Provides extension methods for . +public static class ChatClientExtensions +{ + /// + /// Get chat response which may contain multiple choices for the prompt and settings. + /// + /// Target chat client service. + /// The standardized prompt input. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . + /// Get chat response with choices generated by the remote model + internal static Task GetResponseAsync( + this IChatClient chatClient, + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var chatOptions = executionSettings.ToChatOptions(kernel); + + // Try to parse the text as a chat history + if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) + { + var messageList = chatHistoryFromPrompt.ToChatMessageList(); + return chatClient.GetResponseAsync(messageList, chatOptions, cancellationToken); + } + + return chatClient.GetResponseAsync(prompt, chatOptions, cancellationToken); + } + + /// Creates an for the specified . + /// The chat client to be represented as a chat completion service. + /// An optional that can be used to resolve services to use in the instance. + /// + /// The . If is an , will + /// be returned. Otherwise, a new will be created that wraps . + /// + [Experimental("SKEXP0001")] + public static IChatCompletionService AsChatCompletionService(this IChatClient client, IServiceProvider? serviceProvider = null) + { + Verify.NotNull(client); + + return client is IChatCompletionService chatCompletionService ? + chatCompletionService : + new ChatClientChatCompletionService(client, serviceProvider); + } + + /// + /// Get the model identifier for the specified . + /// + [Experimental("SKEXP0001")] + public static string? GetModelId(this IChatClient client) + { + Verify.NotNull(client); + + return client.GetService()?.ModelId; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs new file mode 100644 index 000000000000..24117a28b2ee --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.Extensions.AI; + +internal static class ChatMessageExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. 
+ internal static ChatMessageContent ToChatMessageContent(this ChatMessage message, ChatResponse? response = null) + { + ChatMessageContent result = new() + { + ModelId = response?.ModelId, + AuthorName = message.AuthorName, + InnerContent = response?.RawRepresentation ?? message.RawRepresentation, + Metadata = message.AdditionalProperties, + Role = new AuthorRole(message.Role.Value), + }; + + foreach (AIContent content in message.Contents) + { + KernelContent? resultContent = null; + switch (content) + { + case Microsoft.Extensions.AI.TextContent tc: + resultContent = new Microsoft.SemanticKernel.TextContent(tc.Text); + break; + + case Microsoft.Extensions.AI.DataContent dc when dc.MediaTypeStartsWith("image/"): + resultContent = dc.Data is not null ? + new Microsoft.SemanticKernel.ImageContent(dc.Uri) : + new Microsoft.SemanticKernel.ImageContent(new Uri(dc.Uri)); + break; + + case Microsoft.Extensions.AI.DataContent dc when dc.MediaTypeStartsWith("audio/"): + resultContent = dc.Data is not null ? + new Microsoft.SemanticKernel.AudioContent(dc.Uri) : + new Microsoft.SemanticKernel.AudioContent(new Uri(dc.Uri)); + break; + + case Microsoft.Extensions.AI.DataContent dc: + resultContent = dc.Data is not null ? + new Microsoft.SemanticKernel.BinaryContent(dc.Uri) : + new Microsoft.SemanticKernel.BinaryContent(new Uri(dc.Uri)); + break; + + case Microsoft.Extensions.AI.FunctionCallContent fcc: + resultContent = new Microsoft.SemanticKernel.FunctionCallContent(fcc.Name, null, fcc.CallId, fcc.Arguments is not null ? new(fcc.Arguments) : null); + break; + + case Microsoft.Extensions.AI.FunctionResultContent frc: + resultContent = new Microsoft.SemanticKernel.FunctionResultContent(callId: frc.CallId, result: frc.Result); + break; + } + + if (resultContent is not null) + { + resultContent.Metadata = content.AdditionalProperties; + resultContent.InnerContent = content.RawRepresentation; + resultContent.ModelId = response?.ModelId; + result.Items.Add(resultContent); + } + } + + return result; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatResponseUpdateExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatResponseUpdateExtensions.cs new file mode 100644 index 000000000000..da505c4d131b --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatResponseUpdateExtensions.cs @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for . +internal static class ChatResponseUpdateExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. + internal static StreamingChatMessageContent ToStreamingChatMessageContent(this ChatResponseUpdate update) + { + StreamingChatMessageContent content = new( + update.Role is not null ? new AuthorRole(update.Role.Value.Value) : null, + null) + { + InnerContent = update.RawRepresentation, + ChoiceIndex = update.ChoiceIndex, + Metadata = update.AdditionalProperties, + ModelId = update.ModelId + }; + + foreach (AIContent item in update.Contents) + { + StreamingKernelContent? resultContent = + item is Microsoft.Extensions.AI.TextContent tc ? new Microsoft.SemanticKernel.StreamingTextContent(tc.Text) : + item is Microsoft.Extensions.AI.FunctionCallContent fcc ? 
+ new Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent(fcc.CallId, fcc.Name, fcc.Arguments is not null ? + JsonSerializer.Serialize(fcc.Arguments!, AbstractionsJsonContext.Default.IDictionaryStringObject!) : + null) : + null; + + if (resultContent is not null) + { + resultContent.ModelId = update.ModelId; + content.Items.Add(resultContent); + } + } + + return content; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs index 419dca381015..3a270a453bfe 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs @@ -3,12 +3,8 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; -using System.Diagnostics.CodeAnalysis; -using System.Globalization; using System.Linq; using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -23,7 +19,7 @@ internal sealed class ChatClientChatCompletionService : IChatCompletionService private readonly IChatClient _chatClient; /// Initializes the for . - public ChatClientChatCompletionService(IChatClient chatClient, IServiceProvider? serviceProvider) + internal ChatClientChatCompletionService(IChatClient chatClient, IServiceProvider? serviceProvider) { Verify.NotNull(chatClient); @@ -54,20 +50,20 @@ public async Task> GetChatMessageContentsAsync { Verify.NotNull(chatHistory); - var messageList = ChatCompletionServiceExtensions.ToChatMessageList(chatHistory); + var messageList = chatHistory.ToChatMessageList(); var currentSize = messageList.Count; var completion = await this._chatClient.GetResponseAsync( messageList, - ToChatOptions(executionSettings, kernel), + executionSettings.ToChatOptions(kernel), cancellationToken).ConfigureAwait(false); chatHistory.AddRange( messageList .Skip(currentSize) - .Select(m => ChatCompletionServiceExtensions.ToChatMessageContent(m))); + .Select(m => m.ToChatMessageContent())); - return completion.Choices.Select(m => ChatCompletionServiceExtensions.ToChatMessageContent(m, completion)).ToList(); + return completion.Choices.Select(m => m.ToChatMessageContent(completion)).ToList(); } /// @@ -77,234 +73,11 @@ public async IAsyncEnumerable GetStreamingChatMessa Verify.NotNull(chatHistory); await foreach (var update in this._chatClient.GetStreamingResponseAsync( - ChatCompletionServiceExtensions.ToChatMessageList(chatHistory), - ToChatOptions(executionSettings, kernel), + chatHistory.ToChatMessageList(), + executionSettings.ToChatOptions(kernel), cancellationToken).ConfigureAwait(false)) { - yield return ToStreamingChatMessageContent(update); + yield return update.ToStreamingChatMessageContent(); } } - - /// Converts a pair of and to a . - private static ChatOptions? ToChatOptions(PromptExecutionSettings? settings, Kernel? kernel) - { - if (settings is null) - { - return null; - } - - if (settings.GetType() != typeof(PromptExecutionSettings)) - { - // If the settings are of a derived type, roundtrip through JSON to the base type in order to try - // to get the derived strongly-typed properties to show up in the loosely-typed ExtensionData dictionary. 
- // This has the unfortunate effect of making all the ExtensionData values into JsonElements, so we lose - // some type fidelity. (As an alternative, we could introduce new interfaces that could be queried for - // in this method and implemented by the derived settings types to control how they're converted to - // ChatOptions.) - settings = JsonSerializer.Deserialize( - JsonSerializer.Serialize(settings, AbstractionsJsonContext.GetTypeInfo(settings.GetType(), null)), - AbstractionsJsonContext.Default.PromptExecutionSettings); - } - - ChatOptions options = new() - { - ModelId = settings!.ModelId - }; - - if (settings!.ExtensionData is IDictionary extensionData) - { - foreach (var entry in extensionData) - { - if (entry.Key.Equals("temperature", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float temperature)) - { - options.Temperature = temperature; - } - else if (entry.Key.Equals("top_p", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float topP)) - { - options.TopP = topP; - } - else if (entry.Key.Equals("top_k", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out int topK)) - { - options.TopK = topK; - } - else if (entry.Key.Equals("seed", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out long seed)) - { - options.Seed = seed; - } - else if (entry.Key.Equals("max_tokens", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out int maxTokens)) - { - options.MaxOutputTokens = maxTokens; - } - else if (entry.Key.Equals("frequency_penalty", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float frequencyPenalty)) - { - options.FrequencyPenalty = frequencyPenalty; - } - else if (entry.Key.Equals("presence_penalty", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float presencePenalty)) - { - options.PresencePenalty = presencePenalty; - } - else if (entry.Key.Equals("stop_sequences", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out IList? stopSequences)) - { - options.StopSequences = stopSequences; - } - else if (entry.Key.Equals("response_format", StringComparison.OrdinalIgnoreCase) && - entry.Value is { } responseFormat) - { - if (TryConvert(responseFormat, out string? responseFormatString)) - { - options.ResponseFormat = responseFormatString switch - { - "text" => ChatResponseFormat.Text, - "json_object" => ChatResponseFormat.Json, - _ => null, - }; - } - else - { - options.ResponseFormat = responseFormat is JsonElement e ? ChatResponseFormat.ForJsonSchema(e) : null; - } - } - else - { - // Roundtripping a derived PromptExecutionSettings through the base type will have put all the - // object values in AdditionalProperties into JsonElements. Convert them back where possible. - object? value = entry.Value; - if (value is JsonElement jsonElement) - { - value = jsonElement.ValueKind switch - { - JsonValueKind.String => jsonElement.GetString(), - JsonValueKind.Number => jsonElement.GetDouble(), // not perfect, but a reasonable heuristic - JsonValueKind.True => true, - JsonValueKind.False => false, - JsonValueKind.Null => null, - _ => value, - }; - - if (jsonElement.ValueKind == JsonValueKind.Array) - { - var enumerator = jsonElement.EnumerateArray(); - - var enumeratorType = enumerator.MoveNext() ? 
enumerator.Current.ValueKind : JsonValueKind.Null; - - switch (enumeratorType) - { - case JsonValueKind.String: - value = enumerator.Select(e => e.GetString()); - break; - case JsonValueKind.Number: - value = enumerator.Select(e => e.GetDouble()); - break; - case JsonValueKind.True or JsonValueKind.False: - value = enumerator.Select(e => e.ValueKind == JsonValueKind.True); - break; - } - } - } - - (options.AdditionalProperties ??= [])[entry.Key] = value; - } - } - } - - if (settings.FunctionChoiceBehavior?.GetConfiguration(new([]) { Kernel = kernel }).Functions is { Count: > 0 } functions) - { - options.ToolMode = settings.FunctionChoiceBehavior is RequiredFunctionChoiceBehavior ? ChatToolMode.RequireAny : ChatToolMode.Auto; - options.Tools = functions.Select(f => f.AsAIFunction(kernel)).Cast().ToList(); - } - - return options; - - // Be a little lenient on the types of the values used in the extension data, - // e.g. allow doubles even when requesting floats. - static bool TryConvert(object? value, [NotNullWhen(true)] out T? result) - { - if (value is not null) - { - // If the value is a T, use it. - if (value is T typedValue) - { - result = typedValue; - return true; - } - - if (value is JsonElement json) - { - // If the value is JsonElement, it likely resulted from JSON serializing as object. - // Try to deserialize it as a T. This currently will only be successful either when - // reflection-based serialization is enabled or T is one of the types special-cased - // in the AbstractionsJsonContext. For other cases with NativeAOT, we would need to - // have a JsonSerializationOptions with the relevant type information. - if (AbstractionsJsonContext.TryGetTypeInfo(typeof(T), firstOptions: null, out JsonTypeInfo? jti)) - { - try - { - result = (T)json.Deserialize(jti)!; - return true; - } - catch (Exception e) when (e is ArgumentException or JsonException or NotSupportedException or InvalidOperationException) - { - } - } - } - else - { - // Otherwise, try to convert it to a T using Convert, in particular to handle conversions between numeric primitive types. - try - { - result = (T)Convert.ChangeType(value, typeof(T), CultureInfo.InvariantCulture); - return true; - } - catch (Exception e) when (e is ArgumentException or FormatException or InvalidCastException or OverflowException) - { - } - } - } - - result = default; - return false; - } - } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - private static StreamingChatMessageContent ToStreamingChatMessageContent(ChatResponseUpdate update) - { - StreamingChatMessageContent content = new( - update.Role is not null ? new AuthorRole(update.Role.Value.Value) : null, - null) - { - InnerContent = update.RawRepresentation, - ChoiceIndex = update.ChoiceIndex, - Metadata = update.AdditionalProperties, - ModelId = update.ModelId - }; - - foreach (AIContent item in update.Contents) - { - StreamingKernelContent? resultContent = - item is Microsoft.Extensions.AI.TextContent tc ? new Microsoft.SemanticKernel.StreamingTextContent(tc.Text) : - item is Microsoft.Extensions.AI.FunctionCallContent fcc ? - new Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent(fcc.CallId, fcc.Name, fcc.Arguments is not null ? - JsonSerializer.Serialize(fcc.Arguments!, AbstractionsJsonContext.Default.IDictionaryStringObject!) 
: - null) : - null; - - if (resultContent is not null) - { - resultContent.ModelId = update.ModelId; - content.Items.Add(resultContent); - } - } - - return content; - } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs index 862239ccd505..a038c169184e 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs @@ -19,7 +19,7 @@ internal sealed class ChatCompletionServiceChatClient : IChatClient private readonly IChatCompletionService _chatCompletionService; /// Initializes the for . - public ChatCompletionServiceChatClient(IChatCompletionService chatCompletionService) + internal ChatCompletionServiceChatClient(IChatCompletionService chatCompletionService) { Verify.NotNull(chatCompletionService); @@ -40,12 +40,12 @@ public ChatCompletionServiceChatClient(IChatCompletionService chatCompletionServ Verify.NotNull(chatMessages); var response = await this._chatCompletionService.GetChatMessageContentAsync( - new ChatHistory(chatMessages.Select(m => ChatCompletionServiceExtensions.ToChatMessageContent(m))), + new ChatHistory(chatMessages.Select(m => m.ToChatMessageContent())), ToPromptExecutionSettings(options), kernel: null, cancellationToken).ConfigureAwait(false); - return new(ChatCompletionServiceExtensions.ToChatMessage(response)) + return new(response.ToChatMessage()) { ModelId = response.ModelId, RawRepresentation = response.InnerContent, @@ -58,12 +58,12 @@ public async IAsyncEnumerable GetStreamingResponseAsync(ILis Verify.NotNull(chatMessages); await foreach (var update in this._chatCompletionService.GetStreamingChatMessageContentsAsync( - new ChatHistory(chatMessages.Select(m => ChatCompletionServiceExtensions.ToChatMessageContent(m))), + new ChatHistory(chatMessages.Select(m => m.ToChatMessageContent())), ToPromptExecutionSettings(options), kernel: null, cancellationToken).ConfigureAwait(false)) { - yield return ToStreamingChatCompletionUpdate(update); + yield return update.ToChatResponseUpdate(); } } @@ -191,46 +191,4 @@ public void Dispose() return settings; } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - private static ChatResponseUpdate ToStreamingChatCompletionUpdate(StreamingChatMessageContent content) - { - ChatResponseUpdate update = new() - { - AdditionalProperties = content.Metadata is not null ? new AdditionalPropertiesDictionary(content.Metadata) : null, - AuthorName = content.AuthorName, - ChoiceIndex = content.ChoiceIndex, - ModelId = content.ModelId, - RawRepresentation = content, - Role = content.Role is not null ? new ChatRole(content.Role.Value.Label) : null, - }; - - foreach (var item in content.Items) - { - AIContent? aiContent = null; - switch (item) - { - case Microsoft.SemanticKernel.StreamingTextContent tc: - aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); - break; - - case Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent fcc: - aiContent = new Microsoft.Extensions.AI.FunctionCallContent( - fcc.CallId ?? string.Empty, - fcc.Name ?? string.Empty, - fcc.Arguments is not null ? JsonSerializer.Deserialize>(fcc.Arguments, AbstractionsJsonContext.Default.IDictionaryStringObject!) 
: null); - break; - } - - if (aiContent is not null) - { - aiContent.RawRepresentation = content; - - update.Contents.Add(aiContent); - } - } - - return update; - } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs index cf5834725700..844d940e5e54 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs @@ -128,148 +128,4 @@ public static IChatClient AsChatClient(this IChatCompletionService service) chatClient : new ChatCompletionServiceChatClient(service); } - - /// Creates an for the specified . - /// The chat client to be represented as a chat completion service. - /// An optional that can be used to resolve services to use in the instance. - /// - /// The . If is an , will - /// be returned. Otherwise, a new will be created that wraps . - /// - [Experimental("SKEXP0001")] - public static IChatCompletionService AsChatCompletionService(this IChatClient client, IServiceProvider? serviceProvider = null) - { - Verify.NotNull(client); - - return client is IChatCompletionService chatCompletionService ? - chatCompletionService : - new ChatClientChatCompletionService(client, serviceProvider); - } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - internal static ChatMessage ToChatMessage(ChatMessageContent content) - { - ChatMessage message = new() - { - AdditionalProperties = content.Metadata is not null ? new(content.Metadata) : null, - AuthorName = content.AuthorName, - RawRepresentation = content.InnerContent, - Role = content.Role.Label is string label ? new ChatRole(label) : ChatRole.User, - }; - - foreach (var item in content.Items) - { - AIContent? aiContent = null; - switch (item) - { - case Microsoft.SemanticKernel.TextContent tc: - aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); - break; - - case Microsoft.SemanticKernel.ImageContent ic: - aiContent = - ic.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ic.DataUri, ic.MimeType ?? "image/*") : - ic.Uri is not null ? new Microsoft.Extensions.AI.DataContent(ic.Uri, ic.MimeType ?? "image/*") : - null; - break; - - case Microsoft.SemanticKernel.AudioContent ac: - aiContent = - ac.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ac.DataUri, ac.MimeType ?? "audio/*") : - ac.Uri is not null ? new Microsoft.Extensions.AI.DataContent(ac.Uri, ac.MimeType ?? "audio/*") : - null; - break; - - case Microsoft.SemanticKernel.BinaryContent bc: - aiContent = - bc.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(bc.DataUri, bc.MimeType) : - bc.Uri is not null ? new Microsoft.Extensions.AI.DataContent(bc.Uri, bc.MimeType) : - null; - break; - - case Microsoft.SemanticKernel.FunctionCallContent fcc: - aiContent = new Microsoft.Extensions.AI.FunctionCallContent(fcc.Id ?? string.Empty, fcc.FunctionName, fcc.Arguments); - break; - - case Microsoft.SemanticKernel.FunctionResultContent frc: - aiContent = new Microsoft.Extensions.AI.FunctionResultContent(frc.CallId ?? string.Empty, frc.Result); - break; - } - - if (aiContent is not null) - { - aiContent.RawRepresentation = item.InnerContent; - aiContent.AdditionalProperties = item.Metadata is not null ? 
new(item.Metadata) : null; - - message.Contents.Add(aiContent); - } - } - - return message; - } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - internal static ChatMessageContent ToChatMessageContent(ChatMessage message, Microsoft.Extensions.AI.ChatResponse? response = null) - { - ChatMessageContent result = new() - { - ModelId = response?.ModelId, - AuthorName = message.AuthorName, - InnerContent = response?.RawRepresentation ?? message.RawRepresentation, - Metadata = message.AdditionalProperties, - Role = new AuthorRole(message.Role.Value), - }; - - foreach (AIContent content in message.Contents) - { - KernelContent? resultContent = null; - switch (content) - { - case Microsoft.Extensions.AI.TextContent tc: - resultContent = new Microsoft.SemanticKernel.TextContent(tc.Text); - break; - - case Microsoft.Extensions.AI.DataContent dc when dc.MediaTypeStartsWith("image/"): - resultContent = dc.Data is not null ? - new Microsoft.SemanticKernel.ImageContent(dc.Uri) : - new Microsoft.SemanticKernel.ImageContent(new Uri(dc.Uri)); - break; - - case Microsoft.Extensions.AI.DataContent dc when dc.MediaTypeStartsWith("audio/"): - resultContent = dc.Data is not null ? - new Microsoft.SemanticKernel.AudioContent(dc.Uri) : - new Microsoft.SemanticKernel.AudioContent(new Uri(dc.Uri)); - break; - - case Microsoft.Extensions.AI.DataContent dc: - resultContent = dc.Data is not null ? - new Microsoft.SemanticKernel.BinaryContent(dc.Uri) : - new Microsoft.SemanticKernel.BinaryContent(new Uri(dc.Uri)); - break; - - case Microsoft.Extensions.AI.FunctionCallContent fcc: - resultContent = new Microsoft.SemanticKernel.FunctionCallContent(fcc.Name, null, fcc.CallId, fcc.Arguments is not null ? new(fcc.Arguments) : null); - break; - - case Microsoft.Extensions.AI.FunctionResultContent frc: - resultContent = new Microsoft.SemanticKernel.FunctionResultContent(callId: frc.CallId, result: frc.Result); - break; - } - - if (resultContent is not null) - { - resultContent.Metadata = content.AdditionalProperties; - resultContent.InnerContent = content.RawRepresentation; - resultContent.ModelId = response?.ModelId; - result.Items.Add(resultContent); - } - } - - return result; - } - - internal static List ToChatMessageList(ChatHistory chatHistory) - => chatHistory.Select(ToChatMessage).ToList(); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs index faf11b2fe450..a238e77417da 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -80,4 +81,7 @@ public static async Task ReduceAsync(this ChatHistory chatHistory, return chatHistory; } + + internal static List ToChatMessageList(this ChatHistory chatHistory) + => chatHistory.Select(m => m.ToChatMessage()).ToList(); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs new file mode 100644 index 000000000000..98bb09be6f85 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs @@ -0,0 +1,204 @@ +// Copyright (c) 
Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +internal static class PromptExecutionSettingsExtensions +{ + /// Converts a pair of and to a . + internal static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? kernel) + { + if (settings is null) + { + return null; + } + + if (settings.GetType() != typeof(PromptExecutionSettings)) + { + // If the settings are of a derived type, roundtrip through JSON to the base type in order to try + // to get the derived strongly-typed properties to show up in the loosely-typed ExtensionData dictionary. + // This has the unfortunate effect of making all the ExtensionData values into JsonElements, so we lose + // some type fidelity. (As an alternative, we could introduce new interfaces that could be queried for + // in this method and implemented by the derived settings types to control how they're converted to + // ChatOptions.) + settings = JsonSerializer.Deserialize( + JsonSerializer.Serialize(settings, AbstractionsJsonContext.GetTypeInfo(settings.GetType(), null)), + AbstractionsJsonContext.Default.PromptExecutionSettings); + } + + ChatOptions options = new() + { + ModelId = settings!.ModelId + }; + + if (settings!.ExtensionData is IDictionary extensionData) + { + foreach (var entry in extensionData) + { + if (entry.Key.Equals("temperature", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float temperature)) + { + options.Temperature = temperature; + } + else if (entry.Key.Equals("top_p", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float topP)) + { + options.TopP = topP; + } + else if (entry.Key.Equals("top_k", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out int topK)) + { + options.TopK = topK; + } + else if (entry.Key.Equals("seed", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out long seed)) + { + options.Seed = seed; + } + else if (entry.Key.Equals("max_tokens", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out int maxTokens)) + { + options.MaxOutputTokens = maxTokens; + } + else if (entry.Key.Equals("frequency_penalty", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float frequencyPenalty)) + { + options.FrequencyPenalty = frequencyPenalty; + } + else if (entry.Key.Equals("presence_penalty", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float presencePenalty)) + { + options.PresencePenalty = presencePenalty; + } + else if (entry.Key.Equals("stop_sequences", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out IList? stopSequences)) + { + options.StopSequences = stopSequences; + } + else if (entry.Key.Equals("response_format", StringComparison.OrdinalIgnoreCase) && + entry.Value is { } responseFormat) + { + if (TryConvert(responseFormat, out string? responseFormatString)) + { + options.ResponseFormat = responseFormatString switch + { + "text" => ChatResponseFormat.Text, + "json_object" => ChatResponseFormat.Json, + _ => null, + }; + } + else + { + options.ResponseFormat = responseFormat is JsonElement e ? 
ChatResponseFormat.ForJsonSchema(e) : null; + } + } + else + { + // Roundtripping a derived PromptExecutionSettings through the base type will have put all the + // object values in AdditionalProperties into JsonElements. Convert them back where possible. + object? value = entry.Value; + if (value is JsonElement jsonElement) + { + value = jsonElement.ValueKind switch + { + JsonValueKind.String => jsonElement.GetString(), + JsonValueKind.Number => jsonElement.GetDouble(), // not perfect, but a reasonable heuristic + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => value, + }; + + if (jsonElement.ValueKind == JsonValueKind.Array) + { + var enumerator = jsonElement.EnumerateArray(); + + var enumeratorType = enumerator.MoveNext() ? enumerator.Current.ValueKind : JsonValueKind.Null; + + switch (enumeratorType) + { + case JsonValueKind.String: + value = enumerator.Select(e => e.GetString()); + break; + case JsonValueKind.Number: + value = enumerator.Select(e => e.GetDouble()); + break; + case JsonValueKind.True or JsonValueKind.False: + value = enumerator.Select(e => e.ValueKind == JsonValueKind.True); + break; + } + } + } + + (options.AdditionalProperties ??= [])[entry.Key] = value; + } + } + } + + if (settings.FunctionChoiceBehavior?.GetConfiguration(new([]) { Kernel = kernel }).Functions is { Count: > 0 } functions) + { + options.ToolMode = settings.FunctionChoiceBehavior is RequiredFunctionChoiceBehavior ? ChatToolMode.RequireAny : ChatToolMode.Auto; + options.Tools = functions.Select(f => f.AsAIFunction(kernel)).Cast().ToList(); + } + + return options; + + // Be a little lenient on the types of the values used in the extension data, + // e.g. allow doubles even when requesting floats. + static bool TryConvert(object? value, [NotNullWhen(true)] out T? result) + { + if (value is not null) + { + // If the value is a T, use it. + if (value is T typedValue) + { + result = typedValue; + return true; + } + + if (value is JsonElement json) + { + // If the value is JsonElement, it likely resulted from JSON serializing as object. + // Try to deserialize it as a T. This currently will only be successful either when + // reflection-based serialization is enabled or T is one of the types special-cased + // in the AbstractionsJsonContext. For other cases with NativeAOT, we would need to + // have a JsonSerializationOptions with the relevant type information. + if (AbstractionsJsonContext.TryGetTypeInfo(typeof(T), firstOptions: null, out JsonTypeInfo? jti)) + { + try + { + result = (T)json.Deserialize(jti)!; + return true; + } + catch (Exception e) when (e is ArgumentException or JsonException or NotSupportedException or InvalidOperationException) + { + } + } + } + else + { + // Otherwise, try to convert it to a T using Convert, in particular to handle conversions between numeric primitive types. 
+ try + { + result = (T)Convert.ChangeType(value, typeof(T), CultureInfo.InvariantCulture); + return true; + } + catch (Exception e) when (e is ArgumentException or FormatException or InvalidCastException or OverflowException) + { + } + } + } + + result = default; + return false; + } + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/CompatibilitySuppressions.xml b/dotnet/src/SemanticKernel.Abstractions/CompatibilitySuppressions.xml new file mode 100644 index 000000000000..da61649a30bd --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/CompatibilitySuppressions.xml @@ -0,0 +1,18 @@ + + + + + CP0002 + M:Microsoft.SemanticKernel.ChatCompletion.ChatCompletionServiceExtensions.AsChatCompletionService(Microsoft.Extensions.AI.IChatClient,System.IServiceProvider) + lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll + lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.ChatCompletion.ChatCompletionServiceExtensions.AsChatCompletionService(Microsoft.Extensions.AI.IChatClient,System.IServiceProvider) + lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll + true + + \ No newline at end of file diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs new file mode 100644 index 000000000000..2e8d45ea89e0 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +internal static class ChatMessageContentExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. + internal static ChatMessage ToChatMessage(this ChatMessageContent content) + { + ChatMessage message = new() + { + AdditionalProperties = content.Metadata is not null ? new(content.Metadata) : null, + AuthorName = content.AuthorName, + RawRepresentation = content.InnerContent, + Role = content.Role.Label is string label ? new ChatRole(label) : ChatRole.User, + }; + + foreach (var item in content.Items) + { + AIContent? aiContent = null; + switch (item) + { + case Microsoft.SemanticKernel.TextContent tc: + aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); + break; + + case Microsoft.SemanticKernel.ImageContent ic: + aiContent = + ic.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ic.DataUri, ic.MimeType ?? "image/*") : + ic.Uri is not null ? new Microsoft.Extensions.AI.DataContent(ic.Uri, ic.MimeType ?? "image/*") : + null; + break; + + case Microsoft.SemanticKernel.AudioContent ac: + aiContent = + ac.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ac.DataUri, ac.MimeType ?? "audio/*") : + ac.Uri is not null ? new Microsoft.Extensions.AI.DataContent(ac.Uri, ac.MimeType ?? "audio/*") : + null; + break; + + case Microsoft.SemanticKernel.BinaryContent bc: + aiContent = + bc.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(bc.DataUri, bc.MimeType) : + bc.Uri is not null ? new Microsoft.Extensions.AI.DataContent(bc.Uri, bc.MimeType) : + null; + break; + + case Microsoft.SemanticKernel.FunctionCallContent fcc: + aiContent = new Microsoft.Extensions.AI.FunctionCallContent(fcc.Id ?? 
string.Empty, fcc.FunctionName, fcc.Arguments); + break; + + case Microsoft.SemanticKernel.FunctionResultContent frc: + aiContent = new Microsoft.Extensions.AI.FunctionResultContent(frc.CallId ?? string.Empty, frc.Result); + break; + } + + if (aiContent is not null) + { + aiContent.RawRepresentation = item.InnerContent; + aiContent.AdditionalProperties = item.Metadata is not null ? new(item.Metadata) : null; + + message.Contents.Add(aiContent); + } + } + + return message; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs new file mode 100644 index 000000000000..ae955bfad14f --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +/// Provides extension methods for . +internal static class StreamingChatMessageContentExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. + internal static ChatResponseUpdate ToChatResponseUpdate(this StreamingChatMessageContent content) + { + ChatResponseUpdate update = new() + { + AdditionalProperties = content.Metadata is not null ? new AdditionalPropertiesDictionary(content.Metadata) : null, + AuthorName = content.AuthorName, + ChoiceIndex = content.ChoiceIndex, + ModelId = content.ModelId, + RawRepresentation = content.InnerContent, + Role = content.Role is not null ? new ChatRole(content.Role.Value.Label) : null, + }; + + foreach (var item in content.Items) + { + AIContent? aiContent = null; + switch (item) + { + case Microsoft.SemanticKernel.StreamingTextContent tc: + aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); + break; + + case Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent fcc: + aiContent = new Microsoft.Extensions.AI.FunctionCallContent( + fcc.CallId ?? string.Empty, + fcc.Name ?? string.Empty, + fcc.Arguments is not null ? JsonSerializer.Deserialize>(fcc.Arguments, AbstractionsJsonContext.Default.IDictionaryStringObject!) : null); + break; + } + + if (aiContent is not null) + { + aiContent.RawRepresentation = content; + + update.Contents.Add(aiContent); + } + } + + return update; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs index 0902a4f80c98..945e9cc7a74e 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs @@ -3,6 +3,8 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Linq; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel; @@ -101,6 +103,104 @@ public FunctionResult(FunctionResult result, object? value = null) { return innerContent; } + + // Attempting to use the new Microsoft.Extensions.AI Chat types will trigger automatic conversion of SK chat contents. 
+ + // ChatMessageContent as ChatMessage + if (typeof(T) == typeof(ChatMessage) + && content is ChatMessageContent chatMessageContent) + { + return (T?)(object)chatMessageContent.ToChatMessage(); + } + + // ChatMessageContent as ChatResponse + if (typeof(T) == typeof(ChatResponse) + && content is ChatMessageContent singleChoiceMessageContent) + { + return (T?)(object)new Microsoft.Extensions.AI.ChatResponse(singleChoiceMessageContent.ToChatMessage()); + } + } + + if (this.Value is IReadOnlyList messageContentList) + { + if (messageContentList.Count == 0) + { + throw new InvalidCastException($"Cannot cast a response with no choices to {typeof(T)}"); + } + + if (typeof(T) == typeof(ChatResponse)) + { + return (T)(object)new ChatResponse(messageContentList.Select(m => m.ToChatMessage()).ToList()); + } + + var firstMessage = messageContentList[0]; + if (typeof(T) == typeof(ChatMessage)) + { + return (T)(object)firstMessage.ToChatMessage(); + } + } + + if (this.Value is Microsoft.Extensions.AI.ChatResponse chatResponse) + { + // If no choices are present, return default + if (chatResponse.Choices.Count == 0) + { + throw new InvalidCastException($"Cannot cast a response with no choices to {typeof(T)}"); + } + + var chatMessage = chatResponse.Message; + if (typeof(T) == typeof(string)) + { + return (T?)(object?)chatMessage.ToString(); + } + + // ChatMessage from a ChatResponse + if (typeof(T) == typeof(ChatMessage)) + { + return (T?)(object)chatMessage; + } + + if (typeof(Microsoft.Extensions.AI.AIContent).IsAssignableFrom(typeof(T))) + { + // Return the first matching content type of a message if any + var updateContent = chatMessage.Contents.FirstOrDefault(c => c is T); + if (updateContent is not null) + { + return (T)(object)updateContent; + } + } + + if (chatMessage.Contents is T contentsList) + { + return contentsList; + } + + if (chatResponse.RawRepresentation is T rawResponseRepresentation) + { + return rawResponseRepresentation; + } + + if (chatMessage.RawRepresentation is T rawMessageRepresentation) + { + return rawMessageRepresentation; + } + + if (typeof(Microsoft.Extensions.AI.AIContent).IsAssignableFrom(typeof(T))) + { + // Return the first matching content type of a message if any + var updateContent = chatMessage.Contents.FirstOrDefault(c => c is T); + if (updateContent is not null) + { + return (T)(object)updateContent; + } + } + + // Avoid breaking changes this transformation will be dropped once we migrate fully to Microsoft.Extensions.AI abstractions. + // This is also necessary to don't break existing code using KernelContents when using IChatClient connectors. + if (typeof(KernelContent).IsAssignableFrom(typeof(T))) + { + return (T)(object)chatMessage.ToChatMessageContent(); + } } throw new InvalidCastException($"Cannot cast {this.Value.GetType()} to {typeof(T)}"); diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs index 30a3ee7794e5..679864841dbb 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs @@ -64,7 +64,7 @@ public static class AIServiceExtensions /// /// /// Specifies the type of the required. This must be the same type - /// with which the service was registered in the orvia + /// with which the service was registered in the or via /// the . /// /// The to use to select a service from the . 
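The IChatClientSelector contract introduced in the next file mirrors IAIServiceSelector for Microsoft.Extensions.AI chat clients. A minimal custom implementation might look like the sketch below; the selector name and the "chat1" key are illustrative assumptions rather than code from this change.

    // Hypothetical selector that always returns the IChatClient registered under the key "chat1".
    // Assumes: using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection;
    internal sealed class FirstChatClientSelector : IChatClientSelector
    {
        public bool TrySelectChatClient<T>(
            Kernel kernel,
            KernelFunction function,
            KernelArguments arguments,
            [NotNullWhen(true)] out T? service,
            out PromptExecutionSettings? serviceSettings) where T : class, IChatClient
        {
            serviceSettings = null; // this sketch applies no per-service execution settings

            // Look up a keyed IChatClient from the kernel's service provider, as the unit tests later in this diff do.
            if ((kernel.Services as IKeyedServiceProvider)?.GetKeyedService<IChatClient>("chat1") is T match)
            {
                service = match;
                return true;
            }

            service = null;
            return false;
        }
    }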
diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs index 93064508d118..353abb9715cc 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs @@ -16,7 +16,7 @@ public interface IAIServiceSelector /// /// /// Specifies the type of the required. This must be the same type - /// with which the service was registered in the orvia + /// with which the service was registered in the or via /// the . /// /// The containing services, plugins, and other state for use throughout the operation. diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/IChatClientSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/IChatClientSelector.cs new file mode 100644 index 000000000000..30f8e2bcb4e6 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Services/IChatClientSelector.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.SemanticKernel; + +#pragma warning disable CA1716 // Identifiers should not match keywords + +/// +/// Represents a selector which will return a combination of the containing instances of T and its pairing +/// from the specified provider based on the model settings. +/// +[Experimental("SKEXP0001")] +public interface IChatClientSelector +{ + /// + /// Resolves an and associated from the specified + /// based on a and associated . + /// + /// + /// Specifies the type of the required. This must be the same type + /// with which the service was registered in the or via + /// the . + /// + /// The containing services, plugins, and other state for use throughout the operation. + /// The function. + /// The function arguments. + /// The selected service, or null if none was selected. + /// The settings associated with the selected service. This may be null even if a service is selected. + /// true if a matching service was selected; otherwise, false. + bool TrySelectChatClient<T>( + Kernel kernel, + KernelFunction function, + KernelArguments arguments, + [NotNullWhen(true)] out T? service, + out PromptExecutionSettings? serviceSettings) where T : class, IChatClient; +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs index 1200acd3a803..c11851c4d46d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs @@ -3,7 +3,9 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Linq; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Services; @@ -11,15 +13,23 @@ namespace Microsoft.SemanticKernel.Services; /// Implementation of that selects the AI service based on the order of the execution settings. /// Uses the service id or model id to select the preferred service provider and then returns the service and associated execution settings.
/// -internal sealed class OrderedAIServiceSelector : IAIServiceSelector +internal sealed class OrderedAIServiceSelector : IAIServiceSelector, IChatClientSelector { public static OrderedAIServiceSelector Instance { get; } = new(); /// - public bool TrySelectAIService( + [Experimental("SKEXP0001")] + public bool TrySelectChatClient(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class, IChatClient + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); + + /// + public bool TrySelectAIService(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class, IAIService + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); + + private bool TrySelect( Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, - out PromptExecutionSettings? serviceSettings) where T : class, IAIService + out PromptExecutionSettings? serviceSettings) where T : class { // Allow the execution settings from the kernel arguments to take precedence var executionSettings = arguments.ExecutionSettings ?? function.ExecutionSettings; @@ -94,11 +104,20 @@ kernel.Services is IKeyedServiceProvider ? kernel.Services.GetService(); } - private T? GetServiceByModelId(Kernel kernel, string modelId) where T : class, IAIService + private T? GetServiceByModelId(Kernel kernel, string modelId) where T : class { foreach (var service in kernel.GetAllServices()) { - string? serviceModelId = service.GetModelId(); + string? serviceModelId = null; + if (service is IAIService aiService) + { + serviceModelId = aiService.GetModelId(); + } + else if (service is IChatClient chatClient) + { + serviceModelId = chatClient.GetModelId(); + } + if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId == modelId) { return service; diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 367e5e7a2553..3fc35c0f3d15 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -7,11 +7,14 @@ using System.Diagnostics.Metrics; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel.AI.ChatCompletion; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Services; using Microsoft.SemanticKernel.TextGeneration; @@ -252,6 +255,7 @@ protected override async ValueTask InvokeCoreAsync( { IChatCompletionService chatCompletion => await this.GetChatCompletionResultAsync(chatCompletion, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false), ITextGenerationService textGeneration => await this.GetTextGenerationResultAsync(textGeneration, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false), + IChatClient chatClient => await this.GetChatClientResultAsync(chatClient, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false), // The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector. 
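+                // Illustrative registration (an assumed usage example, not code from this change): the new IChatClient arm
+                // above lets a prompt function run against a kernel that only registers a chat client, e.g.
+                //   builder.Services.AddKeyedSingleton<IChatClient>("service2", chatClient);
+                // Any other service type still falls through to the NotSupportedException below.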
_ => throw new NotSupportedException($"The AI service {promptRenderingResult.AIService.GetType()} is not supported. Supported services are {typeof(IChatCompletionService)} and {typeof(ITextGenerationService)}") }; @@ -271,7 +275,7 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync? asyncReference = null; + IAsyncEnumerable? asyncReference = null; if (result.AIService is IChatCompletionService chatCompletion) { @@ -281,32 +285,114 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync (TResult)(object)content.ToString(), - - _ when content is TResult contentAsT - => contentAsT, - - _ when content.InnerContent is TResult innerContentAsT - => innerContentAsT, - - _ when typeof(TResult) == typeof(byte[]) - => (TResult)(object)content.ToByteArray(), + if (typeof(TResult) == typeof(string)) + { + yield return (TResult)(object)kernelContent.ToString(); + continue; + } + + if (content is TResult contentAsT) + { + yield return contentAsT; + continue; + } + + if (kernelContent.InnerContent is TResult innerContentAsT) + { + yield return innerContentAsT; + continue; + } + + if (typeof(TResult) == typeof(byte[])) + { + if (content is StreamingKernelContent byteKernelContent) + { + yield return (TResult)(object)byteKernelContent.ToByteArray(); + continue; + } + } + + // Attempting to use the new Microsoft Extensions AI types will trigger automatic conversion of SK chat contents. + if (typeof(ChatResponseUpdate).IsAssignableFrom(typeof(TResult)) + && content is StreamingChatMessageContent streamingChatMessageContent) + { + yield return (TResult)(object)streamingChatMessageContent.ToChatResponseUpdate(); + continue; + } + } + else if (content is ChatResponseUpdate chatUpdate) + { + if (typeof(TResult) == typeof(string)) + { + yield return (TResult)(object)chatUpdate.ToString(); + continue; + } + + if (chatUpdate is TResult contentAsT) + { + yield return contentAsT; + continue; + } + + if (chatUpdate.Contents is TResult contentListsAsT) + { + yield return contentListsAsT; + continue; + } + + if (chatUpdate.RawRepresentation is TResult rawRepresentationAsT) + { + yield return rawRepresentationAsT; + continue; + } + + if (typeof(Microsoft.Extensions.AI.AIContent).IsAssignableFrom(typeof(TResult))) + { + // Return the first matching content type of an update if any + var updateContent = chatUpdate.Contents.FirstOrDefault(c => c is TResult); + if (updateContent is not null) + { + yield return (TResult)(object)updateContent; + continue; + } + } + + if (typeof(TResult) == typeof(byte[])) + { + DataContent? dataContent = (DataContent?)chatUpdate.Contents.FirstOrDefault(c => c is DataContent dataContent && dataContent.Data.HasValue); + if (dataContent is not null) + { + yield return (TResult)(object)dataContent.Data!.Value.ToArray(); + continue; + } + } + + // Avoid breaking changes this transformation will be dropped once we migrate fully to Microsoft Extensions AI abstractions. + // This is also necessary to don't break existing code using KernelContents when using IChatClient connectors. + if (typeof(StreamingKernelContent).IsAssignableFrom(typeof(TResult))) + { + yield return (TResult)(object)chatUpdate.ToStreamingChatMessageContent(); + continue; + } + } - _ => throw new NotSupportedException($"The specific type {typeof(TResult)} is not supported. 
Support types are {typeof(StreamingTextContent)}, string, byte[], or a matching type for {typeof(StreamingTextContent)}.{nameof(StreamingTextContent.InnerContent)} property") - }; + throw new NotSupportedException($"The specific type {typeof(TResult)} is not supported. Supported types are derivations of {typeof(StreamingKernelContent)}, {typeof(ChatResponseUpdate)}, string, byte[], or a matching type for {typeof(StreamingKernelContent)}.{nameof(StreamingKernelContent.InnerContent)} property"); } // There is no post cancellation check to override the result as the stream data was already sent. @@ -450,13 +536,13 @@ private KernelFunctionFromPrompt( private const string MeasurementModelTagName = "semantic_kernel.function.model_id"; /// to record function invocation prompt token usage. - private static readonly Histogram<int> s_invocationTokenUsagePrompt = s_meter.CreateHistogram<int>( + private static readonly Histogram<long> s_invocationTokenUsagePrompt = s_meter.CreateHistogram<long>( name: "semantic_kernel.function.invocation.token_usage.prompt", unit: "{token}", description: "Measures the prompt token usage"); /// to record function invocation completion token usage. - private static readonly Histogram<int> s_invocationTokenUsageCompletion = s_meter.CreateHistogram<int>( + private static readonly Histogram<long> s_invocationTokenUsageCompletion = s_meter.CreateHistogram<long>( name: "semantic_kernel.function.invocation.token_usage.completion", unit: "{token}", description: "Measures the completion token usage"); @@ -481,7 +567,7 @@ private async Task RenderPromptAsync( { var serviceSelector = kernel.ServiceSelector; - IAIService? aiService; + IAIService? aiService = null; string renderedPrompt = string.Empty; // Try to use IChatCompletionService. @@ -491,12 +577,41 @@ private async Task RenderPromptAsync( { aiService = chatService; } - else + else if (serviceSelector.TrySelectAIService( + kernel, this, arguments, + out ITextGenerationService? textService, out executionSettings)) { - // If IChatCompletionService isn't available, try to fallback to ITextGenerationService, - // throwing if it's not available. - (aiService, executionSettings) = serviceSelector.SelectAIService<ITextGenerationService>(kernel, this, arguments); + aiService = textService; + } +#pragma warning disable CA2000 // Dispose objects before losing scope + else if (serviceSelector is IChatClientSelector chatClientServiceSelector + && chatClientServiceSelector.TrySelectChatClient(kernel, this, arguments, out var chatClient, out executionSettings)) + { + // Resolves a ChatClient as AIService so it doesn't need to implement IChatCompletionService.
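+                // (Assumed rationale for the CA2000 suppression above: the adapter wraps an IChatClient that is owned by
+                // the service container, so disposing it is left to whoever registered the client, not to this scope.)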
+ aiService = new ChatClientAIService(chatClient); + } + + if (aiService is null) + { + var message = new StringBuilder().Append("No service was found for any of the supported types: ").Append(typeof(IChatCompletionService)).Append(", ").Append(typeof(ITextGenerationService)).Append(", ").Append(typeof(IChatClient)).Append('.'); + if (this.ExecutionSettings is not null) + { + string serviceIds = string.Join("|", this.ExecutionSettings.Keys); + if (!string.IsNullOrEmpty(serviceIds)) + { + message.Append(" Expected serviceIds: ").Append(serviceIds).Append('.'); + } + + string modelIds = string.Join("|", this.ExecutionSettings.Values.Select(model => model.ModelId)); + if (!string.IsNullOrEmpty(modelIds)) + { + message.Append(" Expected modelIds: ").Append(modelIds).Append('.'); + } + } + + throw new KernelException(message.ToString()); } +#pragma warning restore CA2000 // Dispose objects before losing scope Verify.NotNull(aiService); @@ -615,6 +730,46 @@ JsonElement SerializeToElement(object? value) } } + /// + /// Captures usage details, including token information. + /// + private void CaptureUsageDetails(string? modelId, UsageDetails? usageDetails, ILogger logger) + { + if (!logger.IsEnabled(LogLevel.Information) && + !s_invocationTokenUsageCompletion.Enabled && + !s_invocationTokenUsagePrompt.Enabled) + { + // Bail early to avoid unnecessary work. + return; + } + + if (string.IsNullOrWhiteSpace(modelId)) + { + logger.LogInformation("No model ID provided to capture usage details."); + return; + } + + if (usageDetails is null) + { + logger.LogInformation("No usage details was provided."); + return; + } + + if (usageDetails.InputTokenCount.HasValue && usageDetails.OutputTokenCount.HasValue) + { + TagList tags = new() { + { MeasurementFunctionTagName, this.Name }, + { MeasurementModelTagName, modelId } + }; + s_invocationTokenUsagePrompt.Record(usageDetails.InputTokenCount.Value, in tags); + s_invocationTokenUsageCompletion.Record(usageDetails.OutputTokenCount.Value, in tags); + } + else + { + logger.LogWarning("Unable to get token details from model result."); + } + } + private async Task GetChatCompletionResultAsync( IChatCompletionService chatCompletion, Kernel kernel, @@ -646,6 +801,40 @@ private async Task GetChatCompletionResultAsync( return new FunctionResult(this, chatContents, kernel.Culture) { RenderedPrompt = promptRenderingResult.RenderedPrompt }; } + private async Task GetChatClientResultAsync( + IChatClient chatClient, + Kernel kernel, + PromptRenderingResult promptRenderingResult, + CancellationToken cancellationToken) + { + var chatResponse = await chatClient.GetResponseAsync( + promptRenderingResult.RenderedPrompt, + promptRenderingResult.ExecutionSettings, + kernel, + cancellationToken).ConfigureAwait(false); + + if (chatResponse.Choices is { Count: 0 }) + { + return new FunctionResult(this, chatResponse) + { + Culture = kernel.Culture, + RenderedPrompt = promptRenderingResult.RenderedPrompt + }; + } + + var modelId = chatClient.GetService()?.ModelId; + + // Usage details are global and duplicated for each chat message content, use first one to get usage information + this.CaptureUsageDetails(chatClient.GetService()?.ModelId, chatResponse.Usage, this._logger); + + return new FunctionResult(this, chatResponse) + { + Culture = kernel.Culture, + RenderedPrompt = promptRenderingResult.RenderedPrompt, + Metadata = chatResponse.AdditionalProperties, + }; + } + private async Task GetTextGenerationResultAsync( ITextGenerationService textGeneration, Kernel kernel, diff --git 
a/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs index 556799ecc85e..7b13b2f999e8 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs @@ -25,7 +25,7 @@ public void InvalidArgumentsThrow() Assert.Throws("generator", () => EmbeddingGenerationExtensions.AsEmbeddingGenerationService(null!)); Assert.Throws("service", () => ChatCompletionServiceExtensions.AsChatClient(null!)); - Assert.Throws("client", () => ChatCompletionServiceExtensions.AsChatCompletionService(null!)); + Assert.Throws("client", () => Microsoft.SemanticKernel.ChatCompletion.ChatClientExtensions.AsChatCompletionService(null!)); } [Fact] diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs new file mode 100644 index 000000000000..5a67a0ecf370 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Services; +using Xunit; + +namespace SemanticKernel.UnitTests.Functions; + +public class CustomAIChatClientSelectorTests +{ + [Fact] + public void ItGetsChatClientUsingModelIdAttribute() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient = new ChatClientTest(); + builder.Services.AddKeyedSingleton("service1", chatClient); + Kernel kernel = builder.Build(); + + var function = kernel.CreateFunctionFromPrompt("Hello AI"); + IChatClientSelector chatClientSelector = new CustomChatClientSelector(); + + // Act + chatClientSelector.TrySelectChatClient(kernel, function, [], out var selectedChatClient, out var defaultExecutionSettings); + + // Assert + Assert.NotNull(selectedChatClient); + Assert.Equal("Value1", selectedChatClient.GetModelId()); + Assert.Null(defaultExecutionSettings); + selectedChatClient.Dispose(); + } + + private sealed class CustomChatClientSelector : IChatClientSelector + { +#pragma warning disable CS8769 // Nullability of reference types in value doesn't match target type. Cannot use [NotNullWhen] because of access to internals from abstractions. + public bool TrySelectChatClient(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) + where T : class, IChatClient + { + var keyedService = (kernel.Services as IKeyedServiceProvider)?.GetKeyedService("service1"); + if (keyedService is null || keyedService.GetModelId() is null) + { + service = null; + serviceSettings = null; + return false; + } + + service = string.Equals(keyedService.GetModelId(), "Value1", StringComparison.OrdinalIgnoreCase) ? 
keyedService as T : null; + serviceSettings = null; + + if (service is null) + { + throw new InvalidOperationException("Service not found"); + } + + return true; + } + } + + private sealed class ChatClientTest : IChatClient + { + private readonly ChatClientMetadata _metadata; + + public ChatClientTest() + { + this._metadata = new ChatClientMetadata(modelId: "Value1"); + } + + public void Dispose() + { + } + + public Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + return this._metadata; + } + + public IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs index a53d8550c4d7..4697a0958c64 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Services; @@ -33,7 +35,8 @@ public void ItGetsAIServiceUsingArbitraryAttributes() private sealed class CustomAIServiceSelector : IAIServiceSelector { #pragma warning disable CS8769 // Nullability of reference types in value doesn't match target type. Cannot use [NotNullWhen] because of access to internals from abstractions. - bool IAIServiceSelector.TrySelectAIService(Kernel kernel, KernelFunction function, KernelArguments arguments, out T? service, out PromptExecutionSettings? serviceSettings) where T : class + public bool TrySelectAIService(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) + where T : class, IAIService { var keyedService = (kernel.Services as IKeyedServiceProvider)?.GetKeyedService("service1"); if (keyedService is null || keyedService.Attributes is null) @@ -45,6 +48,12 @@ bool IAIServiceSelector.TrySelectAIService(Kernel kernel, KernelFunction func service = keyedService.Attributes.ContainsKey("Key1") ? 
keyedService as T : null; serviceSettings = null; + + if (service is null) + { + throw new InvalidOperationException("Service not found"); + } + return true; } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs index 787718b6e8e4..4d2f5e14d763 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs @@ -4,7 +4,9 @@ using System.Collections.Generic; using System.Globalization; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; using Xunit; +using MEAI = Microsoft.Extensions.AI; namespace SemanticKernel.UnitTests.Functions; @@ -134,4 +136,172 @@ public void GetValueWhenValueIsKernelContentGenericTypeMatchShouldReturn() Assert.Equal(valueType, target.GetValue()); Assert.Equal(valueType, target.GetValue()); } + + [Fact] + public void GetValueConvertsFromMEAIChatMessageToSKChatMessageContent() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAICompletion = OpenAI.Chat.OpenAIChatModelFactory.ChatCompletion( + role: OpenAI.Chat.ChatMessageRole.User, + content: new OpenAI.Chat.ChatMessageContent(expectedValue)); + + var valueType = new MEAI.ChatResponse( + [ + new MEAI.ChatMessage(MEAI.ChatRole.User, expectedValue) + { + RawRepresentation = openAICompletion.Content + }, + new MEAI.ChatMessage(MEAI.ChatRole.Assistant, expectedValue) + { + RawRepresentation = openAICompletion.Content + } + ]) + { + RawRepresentation = openAICompletion + }; + + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + var message = target.GetValue()!; + Assert.Equal(valueType.Message.Text, message.Content); + Assert.Same(valueType.Message.RawRepresentation, message.InnerContent); + } + + [Fact] + public void GetValueConvertsFromSKChatMessageContentToMEAIChatMessage() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAIChatMessage = OpenAI.Chat.ChatMessage.CreateUserMessage(expectedValue); + var valueType = new ChatMessageContent(AuthorRole.User, expectedValue) { InnerContent = openAIChatMessage }; + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + Assert.Equal(valueType.Content, target.GetValue()!.Text); + Assert.Same(valueType.InnerContent, target.GetValue()!.RawRepresentation); + } + + [Fact] + public void GetValueConvertsFromSKChatMessageContentToMEAIChatResponse() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAIChatMessage = OpenAI.Chat.ChatMessage.CreateUserMessage(expectedValue); + var valueType = new ChatMessageContent(AuthorRole.User, expectedValue) { InnerContent = openAIChatMessage }; + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + + Assert.Equal(valueType.Content, target.GetValue()!.Message.Text); + Assert.Same(valueType.InnerContent, target.GetValue()!.Message.RawRepresentation); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(5)] + public void GetValueConvertsFromSKChatMessageContentListToMEAIChatResponse(int listSize) + { + // Arrange + List multipleChoiceResponse = []; + for (int i = 0; i < listSize; i++) + { + multipleChoiceResponse.Add(new ChatMessageContent(AuthorRole.User, Guid.NewGuid().ToString()) + { + InnerContent = OpenAI.Chat.ChatMessage.CreateUserMessage(i.ToString()) + }); + } + FunctionResult target = new(KernelFunctionFactory.CreateFromMethod(() => { }), 
(IReadOnlyList)multipleChoiceResponse); + + // Act and Assert + // Ensure returns the ChatResponse for no choices as well + var result = target.GetValue()!; + for (int i = 0; i < listSize; i++) + { + Assert.Equal(multipleChoiceResponse[i].Content, result.Choices[i].Text); + Assert.Same(multipleChoiceResponse[i].InnerContent, result.Choices[i].RawRepresentation); + } + Assert.Equal(multipleChoiceResponse.Count, result.Choices.Count); + + if (listSize > 0) + { + // Ensure the conversion to the first message works in one or multiple choice response + Assert.Equal(multipleChoiceResponse[0].Content, target.GetValue()!.Text); + Assert.Same(multipleChoiceResponse[0].InnerContent, target.GetValue()!.RawRepresentation); + } + } + + [Fact] + public void GetValueThrowsForEmptyChoicesFromSKChatMessageContentListToMEAITypes() + { + // Arrange + List multipleChoiceResponse = []; + FunctionResult target = new(KernelFunctionFactory.CreateFromMethod(() => { }), (IReadOnlyList)multipleChoiceResponse); + + // Act and Assert + var exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + + exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + } + + [Fact] + public void GetValueCanRetrieveMEAITypes() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAICompletion = OpenAI.Chat.OpenAIChatModelFactory.ChatCompletion( + role: OpenAI.Chat.ChatMessageRole.User, + content: new OpenAI.Chat.ChatMessageContent(expectedValue)); + + var valueType = new MEAI.ChatResponse( + new MEAI.ChatMessage(MEAI.ChatRole.User, expectedValue) + { + RawRepresentation = openAICompletion.Content + }) + { + RawRepresentation = openAICompletion + }; + + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + Assert.Same(valueType, target.GetValue()); + Assert.Same(valueType.Message, target.GetValue()); + Assert.Same(valueType.Message.Contents[0], target.GetValue()); + Assert.Same(valueType.Message.Contents[0], target.GetValue()); + + // Check the the content list is returned + Assert.Same(valueType.Message.Contents, target.GetValue>()!); + Assert.Same(valueType.Message.Contents[0], target.GetValue>()![0]); + Assert.IsType(target.GetValue>()![0]); + + // Check the raw representations are returned + Assert.Same(valueType.RawRepresentation, target.GetValue()!); + Assert.Same(valueType.Message.RawRepresentation, target.GetValue()!); + } + + [Fact] + public void GetValueThrowsForEmptyChoicesToMEAITypes() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var valueType = new MEAI.ChatResponse([]); + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + Assert.Empty(target.GetValue()!.Choices); + + var exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + + exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + + exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs index 72dc5199dafb..2ec0c214b2a5 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs @@ -17,8 +17,7 @@ using Moq; using SemanticKernel.UnitTests.Functions.JsonSerializerContexts; 
using Xunit; - -// ReSharper disable StringLiteralTypo +using MEAI = Microsoft.Extensions.AI; namespace SemanticKernel.UnitTests.Functions; @@ -150,18 +149,18 @@ public async Task ItUsesServiceIdWhenProvidedInMethodAsync() public async Task ItUsesChatServiceIdWhenProvidedInMethodAsync() { // Arrange - var mockTextGeneration1 = new Mock(); - var mockTextGeneration2 = new Mock(); + var mockTextGeneration = new Mock(); + var mockChatCompletion = new Mock(); var fakeTextContent = new TextContent("llmResult"); var fakeChatContent = new ChatMessageContent(AuthorRole.User, "content"); - mockTextGeneration1.Setup(c => c.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeTextContent]); - mockTextGeneration2.Setup(c => c.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeChatContent]); + mockTextGeneration.Setup(c => c.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeTextContent]); + mockChatCompletion.Setup(c => c.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeChatContent]); IKernelBuilder builder = Kernel.CreateBuilder(); - builder.Services.AddKeyedSingleton("service1", mockTextGeneration1.Object); - builder.Services.AddKeyedSingleton("service2", mockTextGeneration2.Object); - builder.Services.AddKeyedSingleton("service3", mockTextGeneration1.Object); + builder.Services.AddKeyedSingleton("service1", mockTextGeneration.Object); + builder.Services.AddKeyedSingleton("service2", mockChatCompletion.Object); + builder.Services.AddKeyedSingleton("service3", mockTextGeneration.Object); Kernel kernel = builder.Build(); var func = kernel.CreateFunctionFromPrompt("my prompt", [new PromptExecutionSettings { ServiceId = "service2" }]); @@ -170,8 +169,41 @@ public async Task ItUsesChatServiceIdWhenProvidedInMethodAsync() await kernel.InvokeAsync(func); // Assert - mockTextGeneration1.Verify(a => a.GetTextContentsAsync("my prompt", It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); - mockTextGeneration2.Verify(a => a.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); + mockTextGeneration.Verify(a => a.GetTextContentsAsync("my prompt", It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + mockChatCompletion.Verify(a => a.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); + } + + [Fact] + public async Task ItUsesChatClientIdWhenProvidedInMethodAsync() + { + // Arrange + var mockTextGeneration = new Mock(); + var mockChatCompletion = new Mock(); + var mockChatClient = new Mock(); + var fakeTextContent = new TextContent("llmResult"); + var fakeChatContent = new ChatMessageContent(AuthorRole.User, "content"); + var fakeChatResponse = new MEAI.ChatResponse(new MEAI.ChatMessage()); + + mockTextGeneration.Setup(c => c.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeTextContent]); + mockChatCompletion.Setup(c => c.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeChatContent]); + mockChatClient.Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(fakeChatResponse); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddKeyedSingleton("service1", mockTextGeneration.Object); + 
builder.Services.AddKeyedSingleton("service2", mockChatClient.Object); + builder.Services.AddKeyedSingleton("service3", mockChatCompletion.Object); + Kernel kernel = builder.Build(); + + var func = kernel.CreateFunctionFromPrompt("my prompt", [new PromptExecutionSettings { ServiceId = "service2" }]); + + // Act + await kernel.InvokeAsync(func); + + // Assert + mockTextGeneration.Verify(a => a.GetTextContentsAsync("my prompt", It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + mockChatCompletion.Verify(a => a.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + mockChatClient.Verify(a => a.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny()), Times.Once()); } [Fact] @@ -194,7 +226,7 @@ public async Task ItFailsIfInvalidServiceIdIsProvidedAsync() var exception = await Assert.ThrowsAsync(() => kernel.InvokeAsync(func)); // Assert - Assert.Equal("Required service of type Microsoft.SemanticKernel.TextGeneration.ITextGenerationService not registered. Expected serviceIds: service3.", exception.Message); + Assert.Contains("Expected serviceIds: service3.", exception.Message); } [Fact] @@ -219,6 +251,28 @@ public async Task ItParsesStandardizedPromptWhenServiceIsChatCompletionAsync() Assert.Equal("How many 20 cents can I get from 1 dollar?", fakeService.ChatHistory[1].Content); } + [Fact] + public async Task ItParsesStandardizedPromptWhenServiceIsChatClientAsync() + { + using var fakeService = new FakeChatClient(); + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + await kernel.InvokeAsync(function); + + Assert.NotNull(fakeService.ChatMessages); + Assert.Equal(2, fakeService.ChatMessages.Count); + Assert.Equal("You are a helpful assistant.", fakeService.ChatMessages[0].Text); + Assert.Equal("How many 20 cents can I get from 1 dollar?", fakeService.ChatMessages[1].Text); + } + [Fact] public async Task ItParsesStandardizedPromptWhenServiceIsStreamingChatCompletionAsync() { @@ -342,6 +396,76 @@ public async Task InvokeAsyncReturnsTheConnectorChatResultWhenInServiceIsOnlyCha Assert.Equal("something", result.GetValue()!.ToString()); } + [Fact] + public async Task InvokeAsyncReturnsTheConnectorChatResultWhenInServiceIsOnlyChatClientAsync() + { + var customTestType = new CustomTestType(); + var fakeChatMessage = new MEAI.ChatMessage(MEAI.ChatRole.User, "something") { RawRepresentation = customTestType }; + var fakeChatResponse = new MEAI.ChatResponse(fakeChatMessage); + Mock mockChatClient = new(); + mockChatClient.Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(fakeChatResponse); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + using var chatClient = mockChatClient.Object; + KernelBuilder builder = new(); + builder.Services.AddTransient((sp) => chatClient); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything"); + + var result = await kernel.InvokeAsync(function); + + Assert.Equal("something", result.GetValue()); + Assert.Equal("something", result.GetValue()!.Text); + Assert.Equal(MEAI.ChatRole.User, result.GetValue()!.Role); + Assert.Same(customTestType, result.GetValue()!); + Assert.Equal("something", 
result.GetValue()!.ToString()); + Assert.Equal("something", result.GetValue()!.ToString()); + } + + [Fact] + public async Task InvokeAsyncReturnsTheConnectorChatResultChoicesWhenInServiceIsOnlyChatClientAsync() + { + var customTestType = new CustomTestType(); + var fakeChatResponse = new MEAI.ChatResponse([ + new MEAI.ChatMessage(MEAI.ChatRole.User, "something 1") { RawRepresentation = customTestType }, + new MEAI.ChatMessage(MEAI.ChatRole.Assistant, "something 2") + ]); + + Mock mockChatClient = new(); + mockChatClient.Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(fakeChatResponse); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + using var chatClient = mockChatClient.Object; + KernelBuilder builder = new(); + builder.Services.AddTransient((sp) => chatClient); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything"); + + var result = await kernel.InvokeAsync(function); + + var response = result.GetValue(); + Assert.NotNull(response); + Assert.Collection(response.Choices, + item1 => + { + Assert.Equal("something 1", item1.Text); Assert.Equal(MEAI.ChatRole.User, item1.Role); + }, + item2 => + { + Assert.Equal("something 2", item2.Text); Assert.Equal(MEAI.ChatRole.Assistant, item2.Role); + }); + + // Other specific types will be checked against the first choice + Assert.Equal("something 1", result.GetValue()); + Assert.Equal("something 1", result.GetValue()!.Text); + Assert.Equal(MEAI.ChatRole.User, result.GetValue()!.Role); + Assert.Same(customTestType, result.GetValue()!); + Assert.Equal("something 1", result.GetValue()!.ToString()); + Assert.Equal("something 1", result.GetValue()!.ToString()); + } + [Fact] public async Task InvokeAsyncReturnsTheConnectorChatResultWhenInServiceIsChatAndTextCompletionAsync() { @@ -947,6 +1071,500 @@ public async Task ItCanBeCloned(JsonSerializerOptions? jsos) Assert.Equal("Prompt with a variable", result); } + [Fact] + public async Task ItCanRetrieveDirectMEAIChatMessageUpdatesAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can " + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?" + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex], update); + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, update.Text); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAITextContentAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can " + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?" 
+ }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].Contents[0], update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAIStringAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can " + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?" + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAIRawRepresentationAsync() + { + var rawRepresentation = OpenAI.Chat.OpenAIChatModelFactory.StreamingChatCompletionUpdate(contentUpdate: new OpenAI.Chat.ChatMessageContent("Hi!")); + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can ", + RawRepresentation = rawRepresentation + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?", + RawRepresentation = rawRepresentation + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].RawRepresentation, update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAIContentListAsync() + { + var rawRepresentation = OpenAI.Chat.OpenAIChatModelFactory.StreamingChatCompletionUpdate(contentUpdate: new OpenAI.Chat.ChatMessageContent("Hi!")); + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! 
How can ", + RawRepresentation = rawRepresentation + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?", + RawRepresentation = rawRepresentation + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync>(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].Contents, update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItConvertsFromMEAIChatMessageUpdateToSKStreamingChatMessageContentAsync() + { + var rawRepresentation = new { test = "a" }; + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can ", + RawRepresentation = rawRepresentation + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?", + RawRepresentation = rawRepresentation + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, update.Content); + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].RawRepresentation, update.InnerContent); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItConvertsFromMEAIChatMessageUpdateToSKStreamingContentAsync() + { + var rawRepresentation = new { test = "a" }; + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can ", + RawRepresentation = rawRepresentation + }, + new MEAI.ChatResponseUpdate + { + Text = "I assist you today?", + RawRepresentation = rawRepresentation + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + var streamingChatContent = Assert.IsType(update); + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].RawRepresentation, update.InnerContent); + + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, streamingChatContent.Content); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToMEAIChatResponseUpdate() + { + var innerContent = new { test = "a" }; + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! 
How can ") + { + InnerContent = innerContent + }, + new StreamingChatMessageContent(null, "I assist you today?") + { + InnerContent = innerContent + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Same(fakeService.GetStreamingChatMessageContentsResult![updateIndex].InnerContent, update.RawRepresentation); + + Assert.Equal(fakeService.GetStreamingChatMessageContentsResult![updateIndex].Content, update.Text); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToStringAsync() + { + var innerContent = new { test = "a" }; + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! How can ") + { + InnerContent = innerContent + }, + new StreamingChatMessageContent(null, "I assist you today?") + { + InnerContent = innerContent + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Equal(fakeService.GetStreamingChatMessageContentsResult![updateIndex].Content, update); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToItselfAsync() + { + var innerContent = new { test = "a" }; + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! How can ") + { + InnerContent = innerContent + }, + new StreamingChatMessageContent(null, "I assist you today?") + { + InnerContent = innerContent + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Same(fakeService.GetStreamingChatMessageContentsResult![updateIndex], update); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToInnerContentAsync() + { + var innerContent = new Random(); + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! 
How can ") + { + InnerContent = innerContent + }, + new StreamingChatMessageContent(null, "I assist you today?") + { + InnerContent = innerContent + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Same(fakeService.GetStreamingChatMessageContentsResult![updateIndex].InnerContent, update); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToBytesAsync() + { + var innerContent = new Random(); + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! How can ") + { + InnerContent = innerContent + }, + new StreamingChatMessageContent(null, "I assist you today?") + { + InnerContent = innerContent + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Equal(fakeService.GetStreamingChatMessageContentsResult![updateIndex].Content, + fakeService.GetStreamingChatMessageContentsResult![updateIndex].Encoding.GetString(update)); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + /// + /// This scenario covers scenarios on attempting to get a ChatResponseUpdate from a ITextGenerationService. + /// + [Fact] + public async Task ItThrowsConvertingFromNonChatSKStreamingContentToMEAIChatResponseUpdate() + { + var fakeService = new FakeTextGenerationService() + { + GetStreamingTextContentsResult = [new StreamingTextContent("Hi!")] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt("How many 20 cents can I get from 1 dollar?"); + + // Act + Assert + await Assert.ThrowsAsync( + () => kernel.InvokeStreamingAsync(function).GetAsyncEnumerator().MoveNextAsync().AsTask()); + } + + [Fact] + public async Task ItThrowsWhenConvertingFromMEAIChatMessageUpdateWithNoDataContentToBytesAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate + { + Role = MEAI.ChatRole.Assistant, + Text = "Hi! How can ", + }] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? 
+ """); + + // Act + Assert + await Assert.ThrowsAsync( + () => kernel.InvokeStreamingAsync(function).GetAsyncEnumerator().MoveNextAsync().AsTask()); + } + public enum KernelInvocationType { InvokePrompt, @@ -990,6 +1608,93 @@ public Task> GetTextContentsAsync(string prompt, Prom } } + private sealed class FakeChatCompletionService : IChatCompletionService + { + public IReadOnlyDictionary Attributes => throw new NotImplementedException(); + + public IList? GetStreamingChatMessageContentsResult { get; set; } + + public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var item in this.GetStreamingChatMessageContentsResult ?? [new StreamingChatMessageContent(AuthorRole.Assistant, "Something")]) + { + yield return item; + } + } +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + } + + private sealed class FakeTextGenerationService : ITextGenerationService + { + public IReadOnlyDictionary Attributes => throw new NotImplementedException(); + + public IList? GetStreamingTextContentsResult { get; set; } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GetStreamingTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var item in this.GetStreamingTextContentsResult ?? [new StreamingTextContent("Something")]) + { + yield return item; + } + } +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } + + private sealed class FakeChatClient : MEAI.IChatClient + { + public IList? ChatMessages { get; private set; } + public IList? GetStreamingResponseResult { get; set; } + + public void Dispose() + { + } + + public Task GetResponseAsync(IList chatMessages, MEAI.ChatOptions? options = null, CancellationToken cancellationToken = default) + { + this.ChatMessages = chatMessages; + return Task.FromResult(new MEAI.ChatResponse(new MEAI.ChatMessage(MEAI.ChatRole.Assistant, "Something"))); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + return new MEAI.ChatClientMetadata(); + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GetStreamingResponseAsync( + IList chatMessages, + MEAI.ChatOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + this.ChatMessages = chatMessages; + foreach (var item in this.GetStreamingResponseResult ?? 
[new MEAI.ChatResponseUpdate { Role = MEAI.ChatRole.Assistant, Text = "Something" }]) + { + yield return item; + } + } +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + } + private Mock GetMockTextGenerationService(IReadOnlyList? textContents = null) { var mockTextGenerationService = new Mock(); @@ -1012,5 +1717,9 @@ private Mock GetMockChatCompletionService(IReadOnlyList< return mockChatCompletionService; } + private sealed class CustomTestType + { + public string Name { get; set; } = "MyCustomType"; + } #endregion } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs index 40121103ce69..1c585726f82d 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs @@ -60,7 +60,7 @@ public async Task ItFailsIfInvalidServiceIdIsProvidedAsync() var exception = await Assert.ThrowsAsync(() => kernel.InvokeAsync(func)); // Assert - Assert.Equal("Required service of type Microsoft.SemanticKernel.TextGeneration.ITextGenerationService not registered. Expected serviceIds: service3.", exception.Message); + Assert.Contains("Expected serviceIds: service3.", exception.Message); } [Theory] diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs index eafac8ac5ca3..b31a98c3f1f3 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs @@ -4,10 +4,13 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Services; using Microsoft.SemanticKernel.TextGeneration; +using Moq; using Xunit; namespace SemanticKernel.UnitTests.Functions; @@ -25,6 +28,7 @@ public void ItThrowsAKernelExceptionForNoServices() // Act // Assert Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); + Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); } [Fact] @@ -46,6 +50,27 @@ public void ItGetsAIServiceConfigurationForSingleAIService() Assert.Null(defaultExecutionSettings); } + [Fact] + public void ItGetsChatClientConfigurationForSingleChatClient() + { + // Arrange + var mockChat = new Mock(); + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddKeyedSingleton("chat1", mockChat.Object); + Kernel kernel = builder.Build(); + + var function = kernel.CreateFunctionFromPrompt("Hello AI"); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var chatClient, out var defaultExecutionSettings); + chatClient?.Dispose(); + + // Assert + Assert.NotNull(chatClient); + Assert.Null(defaultExecutionSettings); + } + [Fact] public void ItGetsAIServiceConfigurationForSingleTextGeneration() { @@ -90,13 +115,41 @@ public void ItGetsAIServiceConfigurationForTextGenerationByServiceId() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItGetsChatClientConfigurationForChatClientByServiceId() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var 
chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var promptConfig = new PromptTemplateConfig() { Template = "Hello AI" }; + var executionSettings = new PromptExecutionSettings(); + promptConfig.AddExecutionSettings(executionSettings, "chat2"); + var function = kernel.CreateFunctionFromPrompt(promptConfig); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + [Fact] public void ItThrowsAKernelExceptionForNotFoundService() { // Arrange IKernelBuilder builder = Kernel.CreateBuilder(); builder.Services.AddKeyedSingleton("service1", new TextGenerationService("model_id_1")); - builder.Services.AddKeyedSingleton("service2", new TextGenerationService("model_id_2")); + builder.Services.AddKeyedSingleton("service2", new ChatCompletionService("model_id_2")); Kernel kernel = builder.Build(); var promptConfig = new PromptTemplateConfig() { Template = "Hello AI" }; @@ -107,6 +160,7 @@ public void ItThrowsAKernelExceptionForNotFoundService() // Act // Assert Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); + Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); } [Fact] @@ -129,6 +183,30 @@ public void ItGetsDefaultServiceForNotFoundModel() Assert.Equal(kernel.GetRequiredService("service2"), aiService); } + [Fact] + public void ItGetsDefaultChatClientForNotFoundModel() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var promptConfig = new PromptTemplateConfig() { Template = "Hello AI" }; + promptConfig.AddExecutionSettings(new PromptExecutionSettings { ModelId = "notfound" }); + var function = kernel.CreateFunctionFromPrompt(promptConfig); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + // Assert + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + } + [Fact] public void ItUsesDefaultServiceForNoExecutionSettings() { @@ -148,6 +226,28 @@ public void ItUsesDefaultServiceForNoExecutionSettings() Assert.Null(defaultExecutionSettings); } + [Fact] + public void ItUsesDefaultChatClientForNoExecutionSettings() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + var function = kernel.CreateFunctionFromPrompt("Hello AI"); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + 
serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + Assert.Null(defaultExecutionSettings); + } + [Fact] public void ItUsesDefaultServiceAndSettingsForDefaultExecutionSettings() { @@ -171,6 +271,32 @@ public void ItUsesDefaultServiceAndSettingsForDefaultExecutionSettings() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItUsesDefaultChatClientAndSettingsForDefaultExecutionSettings() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var executionSettings = new PromptExecutionSettings(); + var function = kernel.CreateFunctionFromPrompt("Hello AI", executionSettings); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + [Fact] public void ItUsesDefaultServiceAndSettingsForDefaultId() { @@ -194,6 +320,32 @@ public void ItUsesDefaultServiceAndSettingsForDefaultId() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItUsesDefaultChatClientAndSettingsForDefaultId() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var executionSettings = new PromptExecutionSettings(); + var function = kernel.CreateFunctionFromPrompt("Hello AI", executionSettings); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + [Theory] [InlineData(new string[] { "modelid_1" }, "modelid_1")] [InlineData(new string[] { "modelid_2" }, "modelid_2")] @@ -228,6 +380,44 @@ public void ItGetsAIServiceConfigurationByOrder(string[] serviceIds, string expe } } + [Theory] + [InlineData(new string[] { "modelid_1" }, "modelid_1")] + [InlineData(new string[] { "modelid_2" }, "modelid_2")] + [InlineData(new string[] { "modelid_3" }, "modelid_3")] + [InlineData(new string[] { "modelid_4", "modelid_1" }, "modelid_1")] + [InlineData(new string[] { "modelid_4", "" }, "modelid_3")] + public void ItGetsChatClientConfigurationByOrder(string[] serviceIds, string expectedModelId) + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("modelid_1"); + using var chatClient2 = new 
ChatClient("modelid_2"); + using var chatClient3 = new ChatClient("modelid_3"); + builder.Services.AddKeyedSingleton("modelid_1", chatClient1); + builder.Services.AddKeyedSingleton("modelid_2", chatClient2); + builder.Services.AddKeyedSingleton("modelid_3", chatClient3); + Kernel kernel = builder.Build(); + + var executionSettings = new Dictionary(); + foreach (var serviceId in serviceIds) + { + executionSettings.Add(serviceId, new PromptExecutionSettings() { ModelId = serviceId }); + } + var function = kernel.CreateFunctionFromPrompt(promptConfig: new PromptTemplateConfig() { Template = "Hello AI", ExecutionSettings = executionSettings }); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService(expectedModelId), aiService); + if (!string.IsNullOrEmpty(defaultExecutionSettings!.ModelId)) + { + Assert.Equal(expectedModelId, defaultExecutionSettings!.ModelId); + } + } + [Fact] public void ItGetsAIServiceConfigurationForTextGenerationByModelId() { @@ -253,6 +443,34 @@ public void ItGetsAIServiceConfigurationForTextGenerationByModelId() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItGetsChatClientConfigurationForChatClientByModelId() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model1"); + using var chatClient2 = new ChatClient("model2"); + builder.Services.AddKeyedSingleton(null, chatClient1); + builder.Services.AddKeyedSingleton(null, chatClient2); + Kernel kernel = builder.Build(); + + var arguments = new KernelArguments(); + var executionSettings = new PromptExecutionSettings() { ModelId = "model2" }; + var function = kernel.CreateFunctionFromPrompt("Hello AI", executionSettings: executionSettings); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, arguments, out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.NotNull(aiService); + Assert.Equal("model2", aiService.GetModelId()); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + #region private private sealed class AIService : IAIService { @@ -270,7 +488,7 @@ public TextGenerationService(string modelId) this._attributes.Add("ModelId", modelId); } - public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } @@ -280,5 +498,73 @@ public IAsyncEnumerable GetStreamingTextContentsAsync(stri throw new NotImplementedException(); } } + + private sealed class ChatCompletionService : IChatCompletionService + { + public IReadOnlyDictionary Attributes => this._attributes; + + private readonly Dictionary _attributes = []; + + public ChatCompletionService(string modelId) + { + this._attributes.Add("ModelId", modelId); + } + + public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? 
kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } + + private sealed class ChatClient : IChatClient + { + public ChatClientMetadata Metadata { get; } + + public ChatClient(string modelId) + { + this.Metadata = new ChatClientMetadata(modelId: modelId); + } + + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task GetResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetStreamingResponseAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(this.Metadata) ? this.Metadata : + null; + } + + public void Dispose() + { + } + } #endregion } diff --git a/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs b/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs index b7ed4fc4a480..ccb96a28466c 100644 --- a/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs @@ -17,6 +17,7 @@ using Microsoft.SemanticKernel.TextGeneration; using Moq; using Xunit; +using MEAI = Microsoft.Extensions.AI; #pragma warning disable CS0618 // Events are deprecated @@ -257,6 +258,110 @@ public async Task InvokeStreamingAsyncCallsConnectorStreamingApiAsync() mockTextCompletion.Verify(m => m.GetStreamingTextContentsAsync(It.IsIn("Write a simple phrase about UnitTests importance"), It.IsAny(), It.IsAny(), It.IsAny()), Times.Exactly(1)); } + [Fact] + public async Task InvokeStreamingAsyncCallsWithMEAIContentsAndChatCompletionApiAsync() + { + // Arrange + var mockChatCompletion = this.SetupStreamingChatCompletionMocks( + new StreamingChatMessageContent(AuthorRole.User, "chunk1"), + new StreamingChatMessageContent(AuthorRole.User, "chunk2")); + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddSingleton(mockChatCompletion.Object); + Kernel kernel = builder.Build(); + var prompt = "Write a simple phrase about UnitTests {{$input}}"; + var expectedPrompt = prompt.Replace("{{$input}}", "importance"); + var sut = KernelFunctionFactory.CreateFromPrompt(prompt); + var variables = new KernelArguments() { [InputParameterName] = "importance" }; + + var chunkCount = 0; + + // Act & Assert + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Text); + chunkCount++; + } + + Assert.Equal(2, chunkCount); + mockChatCompletion.Verify(m => m.GetStreamingChatMessageContentsAsync(It.Is((m) => m[0].Content == expectedPrompt), It.IsAny(), 
It.IsAny(), It.IsAny()), Times.Exactly(1)); + } + + [Fact] + public async Task InvokeStreamingAsyncGenericPermutationsCallsConnectorChatClientAsync() + { + // Arrange + var customRawItem = new MEAI.ChatOptions(); + var mockChatClient = this.SetupStreamingChatClientMock( + new MEAI.ChatResponseUpdate() { Text = "chunk1", RawRepresentation = customRawItem }, + new MEAI.ChatResponseUpdate() { Text = "chunk2", RawRepresentation = customRawItem }); + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddSingleton(mockChatClient.Object); + Kernel kernel = builder.Build(); + var prompt = "Write a simple phrase about UnitTests {{$input}}"; + var expectedPrompt = prompt.Replace("{{$input}}", "importance"); + var sut = KernelFunctionFactory.CreateFromPrompt(prompt); + var variables = new KernelArguments() { [InputParameterName] = "importance" }; + + var totalChunksExpected = 0; + var totalInvocationTimesExpected = 0; + + // Act & Assert + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + totalChunksExpected++; + Assert.Same(customRawItem, chunk.InnerContent); + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Content); + Assert.Same(customRawItem, chunk.InnerContent); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Text); + Assert.Same(customRawItem, chunk.RawRepresentation); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.ToString()); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync>(kernel, variables)) + { + Assert.Contains("chunk", chunk[0].ToString()); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Text); + totalChunksExpected++; + } + + Assert.Equal(totalInvocationTimesExpected * 2, totalChunksExpected); + mockChatClient.Verify(m => m.GetStreamingResponseAsync(It.Is>((m) => m[0].Text == expectedPrompt), It.IsAny(), It.IsAny()), Times.Exactly(totalInvocationTimesExpected)); + } + [Fact] public async Task ValidateInvokeAsync() { @@ -316,6 +421,13 @@ public async IAsyncEnumerable GetStreamingChatMessa return (mockTextContent, mockTextCompletion); } + private Mock SetupStreamingChatCompletionMocks(params StreamingChatMessageContent[] streamingContents) + { + var mockChatCompletion = new Mock(); + mockChatCompletion.Setup(m => m.GetStreamingChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Returns(streamingContents.ToAsyncEnumerable()); + return mockChatCompletion; + } + private Mock SetupStreamingMocks(params StreamingTextContent[] streamingContents) { var mockTextCompletion = new Mock(); @@ -324,6 +436,15 @@ private Mock SetupStreamingMocks(params StreamingTextCon return mockTextCompletion; } + private Mock SetupStreamingChatClientMock(params MEAI.ChatResponseUpdate[] chatResponseUpdates) + { + var mockChatClient = new Mock(); + mockChatClient.Setup(m => 
m.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).Returns(chatResponseUpdates.ToAsyncEnumerable()); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + return mockChatClient; + } + private void AssertFilters(Kernel kernel1, Kernel kernel2) { var functionFilters1 = kernel1.GetAllServices().ToArray();