diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 6418a0b49811..cd5397f77916 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -433,8 +433,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OllamaFunctionCalling", "sa EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenAIRealtime", "samples\Demos\OpenAIRealtime\OpenAIRealtime.csproj", "{6154129E-7A35-44A5-998E-B7001B5EDE14}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "CreateChatGpt", "CreateChatGpt", "{02EA681E-C7D8-13C7-8484-4AC65E1B71E8}" -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "VectorDataIntegrationTests", "VectorDataIntegrationTests", "{4F381919-F1BE-47D8-8558-3187ED04A84F}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "QdrantIntegrationTests", "src\VectorDataIntegrationTests\QdrantIntegrationTests\QdrantIntegrationTests.csproj", "{27D33AB3-4DFF-48BC-8D76-FB2CDF90B707}" @@ -1681,7 +1679,6 @@ Global {B35B1DEB-04DF-4141-9163-01031B22C5D1} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} {481A680F-476A-4627-83DE-2F56C484525E} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {6154129E-7A35-44A5-998E-B7001B5EDE14} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} - {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {4F381919-F1BE-47D8-8558-3187ED04A84F} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} {27D33AB3-4DFF-48BC-8D76-FB2CDF90B707} = {4F381919-F1BE-47D8-8558-3187ED04A84F} {B29A972F-A774-4140-AECF-6B577C476627} = {4F381919-F1BE-47D8-8558-3187ED04A84F} diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs index 0a7960e743c6..f65541945084 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs @@ -13,15 +13,17 @@ namespace Agents; /// public class ChatCompletion_FunctionTermination(ITestOutputHelper output) 
: BaseAgentsTest(output) { - [Fact] - public async Task UseAutoFunctionInvocationFilterWithAgentInvocationAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseAutoFunctionInvocationFilterWithAgentInvocation(bool useChatClient) { // Define the agent ChatCompletionAgent agent = new() { Instructions = "Answer questions about the menu.", - Kernel = CreateKernelWithFilter(), + Kernel = CreateKernelWithFilter(useChatClient), Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; @@ -53,15 +55,60 @@ async Task InvokeAgentAsync(string input) } } - [Fact] - public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocationAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseAutoFunctionInvocationFilterWithAgentChat(bool useChatClient) { // Define the agent ChatCompletionAgent agent = new() { Instructions = "Answer questions about the menu.", - Kernel = CreateKernelWithFilter(), + Kernel = CreateKernelWithFilter(useChatClient), + Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), + }; + + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + agent.Kernel.Plugins.Add(plugin); + + // Create a chat for agent interaction. + AgentGroupChat chat = new(); + + // Respond to user input, invoking functions where appropriate. + await InvokeAgentAsync("Hello"); + await InvokeAgentAsync("What is the special soup?"); + await InvokeAgentAsync("What is the special drink?"); + await InvokeAgentAsync("Thank you"); + + // Display the entire chat history. + WriteChatHistory(await chat.GetChatMessagesAsync().ToArrayAsync()); + + // Local function to invoke agent and display the conversation messages. 
+ async Task InvokeAgentAsync(string input) + { + ChatMessageContent message = new(AuthorRole.User, input); + chat.AddChatMessage(message); + this.WriteAgentChatMessage(message); + + await foreach (ChatMessageContent response in chat.InvokeAsync(agent)) + { + this.WriteAgentChatMessage(response); + } + } + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocation(bool useChatClient) + { + // Define the agent + ChatCompletionAgent agent = + new() + { + Instructions = "Answer questions about the menu.", + Kernel = CreateKernelWithFilter(useChatClient), Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; @@ -115,6 +162,61 @@ async Task InvokeAgentAsync(string input) } } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseAutoFunctionInvocationFilterWithStreamingAgentChat(bool useChatClient) + { + // Define the agent + ChatCompletionAgent agent = + new() + { + Instructions = "Answer questions about the menu.", + Kernel = CreateKernelWithFilter(useChatClient), + Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), + }; + + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + agent.Kernel.Plugins.Add(plugin); + + // Create a chat for agent interaction. + AgentGroupChat chat = new(); + + // Respond to user input, invoking functions where appropriate. + await InvokeAgentAsync("Hello"); + await InvokeAgentAsync("What is the special soup?"); + await InvokeAgentAsync("What is the special drink?"); + await InvokeAgentAsync("Thank you"); + + // Display the entire chat history. + WriteChatHistory(await chat.GetChatMessagesAsync().ToArrayAsync()); + + // Local function to invoke agent and display the conversation messages. 
+ async Task InvokeAgentAsync(string input) + { + ChatMessageContent message = new(AuthorRole.User, input); + chat.AddChatMessage(message); + this.WriteAgentChatMessage(message); + + bool isFirst = false; + await foreach (StreamingChatMessageContent response in chat.InvokeStreamingAsync(agent)) + { + if (string.IsNullOrEmpty(response.Content)) + { + continue; + } + + if (!isFirst) + { + Console.WriteLine($"\n# {response.Role} - {response.AuthorName ?? "*"}:"); + isFirst = true; + } + + Console.WriteLine($"\t > streamed: '{response.Content}'"); + } + } + } + private void WriteChatHistory(IEnumerable chat) { Console.WriteLine("================================"); @@ -126,11 +228,18 @@ private void WriteChatHistory(IEnumerable chat) } } - private Kernel CreateKernelWithFilter() + private Kernel CreateKernelWithFilter(bool useChatClient) { IKernelBuilder builder = Kernel.CreateBuilder(); - base.AddChatCompletionToKernel(builder); + if (useChatClient) + { + base.AddChatClientToKernel(builder); + } + else + { + base.AddChatCompletionToKernel(builder); + } builder.Services.AddSingleton(new AutoInvocationFilter()); diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs index 6c0268c7b4ef..4e820570d63d 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_HistoryReducer.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.ChatCompletion; @@ -19,26 +20,34 @@ public class ChatCompletion_HistoryReducer(ITestOutputHelper output) : BaseTest( /// Demonstrate the use of when directly /// invoking a . 
/// - [Fact] - public async Task TruncatedAgentReductionAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task TruncatedAgentReduction(bool useChatClient) { // Define the agent - ChatCompletionAgent agent = CreateTruncatingAgent(10, 10); + ChatCompletionAgent agent = CreateTruncatingAgent(10, 10, useChatClient, out var chatClient); await InvokeAgentAsync(agent, 50); + + chatClient?.Dispose(); } /// /// Demonstrate the use of when directly /// invoking a . /// - [Fact] - public async Task SummarizedAgentReductionAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SummarizedAgentReduction(bool useChatClient) { // Define the agent - ChatCompletionAgent agent = CreateSummarizingAgent(10, 10); + ChatCompletionAgent agent = CreateSummarizingAgent(10, 10, useChatClient, out var chatClient); await InvokeAgentAsync(agent, 50); + + chatClient?.Dispose(); } // Proceed with dialog by directly invoking the agent and explicitly managing the history. @@ -79,25 +88,30 @@ private async Task InvokeAgentAsync(ChatCompletionAgent agent, int messageCount) } } - private ChatCompletionAgent CreateSummarizingAgent(int reducerMessageCount, int reducerThresholdCount) + private ChatCompletionAgent CreateSummarizingAgent(int reducerMessageCount, int reducerThresholdCount, bool useChatClient, out IChatClient? chatClient) { - Kernel kernel = this.CreateKernelWithChatCompletion(); + Kernel kernel = this.CreateKernelWithChatCompletion(useChatClient, out chatClient); + + var service = useChatClient + ? 
kernel.GetRequiredService().AsChatCompletionService() + : kernel.GetRequiredService(); + return new() { Name = TranslatorName, Instructions = TranslatorInstructions, Kernel = kernel, - HistoryReducer = new ChatHistorySummarizationReducer(kernel.GetRequiredService(), reducerMessageCount, reducerThresholdCount), + HistoryReducer = new ChatHistorySummarizationReducer(service, reducerMessageCount, reducerThresholdCount), }; } - private ChatCompletionAgent CreateTruncatingAgent(int reducerMessageCount, int reducerThresholdCount) => + private ChatCompletionAgent CreateTruncatingAgent(int reducerMessageCount, int reducerThresholdCount, bool useChatClient, out IChatClient? chatClient) => new() { Name = TranslatorName, Instructions = TranslatorInstructions, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out chatClient), HistoryReducer = new ChatHistoryTruncationReducer(reducerMessageCount, reducerThresholdCount), }; } diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_Serialization.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_Serialization.cs index 1bc16f452d6c..9153e4b45cda 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_Serialization.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_Serialization.cs @@ -13,8 +13,10 @@ public class ChatCompletion_Serialization(ITestOutputHelper output) : BaseAgents private const string HostName = "Host"; private const string HostInstructions = "Answer questions about the menu."; - [Fact] - public async Task SerializeAndRestoreAgentGroupChatAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SerializeAndRestoreAgentGroupChat(bool useChatClient) { // Define the agent ChatCompletionAgent agent = @@ -22,7 +24,7 @@ public async Task SerializeAndRestoreAgentGroupChatAsync() { Instructions = HostInstructions, Name = HostName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = 
this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; @@ -51,6 +53,8 @@ public async Task SerializeAndRestoreAgentGroupChatAsync() this.WriteAgentChatMessage(content); } + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(AgentGroupChat chat, string input) { diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs index 1434330c61eb..be69fe412d5e 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs @@ -1,4 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System.ClientModel; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.ChatCompletion; @@ -15,13 +18,13 @@ public class ChatCompletion_ServiceSelection(ITestOutputHelper output) : BaseAge private const string ServiceKeyGood = "chat-good"; private const string ServiceKeyBad = "chat-bad"; - [Fact] - public async Task UseServiceSelectionWithChatCompletionAgentAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseServiceSelectionWithChatCompletionAgent(bool useChatClient) { - // Create kernel with two instances of IChatCompletionService - // One service is configured with a valid API key and the other with an - // invalid key that will result in a 401 Unauthorized error. 
- Kernel kernel = CreateKernelWithTwoServices(); + // Create kernel with two instances of chat services - one good, one bad + Kernel kernel = CreateKernelWithTwoServices(useChatClient); // Define the agent targeting ServiceId = ServiceKeyGood ChatCompletionAgent agentGood = @@ -88,38 +91,78 @@ async Task InvokeAgentAsync(ChatCompletionAgent agent, KernelArguments? argument { Console.WriteLine($"Status: {exception.StatusCode}"); } + catch (ClientResultException cre) + { + Console.WriteLine($"Status: {cre.Status}"); + } } } - private Kernel CreateKernelWithTwoServices() + private Kernel CreateKernelWithTwoServices(bool useChatClient) { IKernelBuilder builder = Kernel.CreateBuilder(); - if (this.UseOpenAIConfig) + if (useChatClient) { - builder.AddOpenAIChatCompletion( - TestConfiguration.OpenAI.ChatModelId, - "bad-key", - serviceId: ServiceKeyBad); - - builder.AddOpenAIChatCompletion( - TestConfiguration.OpenAI.ChatModelId, - TestConfiguration.OpenAI.ApiKey, - serviceId: ServiceKeyGood); + // Add chat clients + if (this.UseOpenAIConfig) + { + builder.Services.AddKeyedChatClient( + ServiceKeyBad, + new OpenAI.OpenAIClient("bad-key").GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient()); + + builder.Services.AddKeyedChatClient( + ServiceKeyGood, + new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient()); + } + else + { + builder.Services.AddKeyedChatClient( + ServiceKeyBad, + new Azure.AI.OpenAI.AzureOpenAIClient( + new Uri(TestConfiguration.AzureOpenAI.Endpoint), + new Azure.AzureKeyCredential("bad-key")) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient()); + + builder.Services.AddKeyedChatClient( + ServiceKeyGood, + new Azure.AI.OpenAI.AzureOpenAIClient( + new Uri(TestConfiguration.AzureOpenAI.Endpoint), + new Azure.AzureKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + 
.AsIChatClient()); + } } else { - builder.AddAzureOpenAIChatCompletion( - TestConfiguration.AzureOpenAI.ChatDeploymentName, - TestConfiguration.AzureOpenAI.Endpoint, - "bad-key", - serviceId: ServiceKeyBad); - - builder.AddAzureOpenAIChatCompletion( - TestConfiguration.AzureOpenAI.ChatDeploymentName, - TestConfiguration.AzureOpenAI.Endpoint, - TestConfiguration.AzureOpenAI.ApiKey, - serviceId: ServiceKeyGood); + // Add chat completion services + if (this.UseOpenAIConfig) + { + builder.AddOpenAIChatCompletion( + TestConfiguration.OpenAI.ChatModelId, + "bad-key", + serviceId: ServiceKeyBad); + + builder.AddOpenAIChatCompletion( + TestConfiguration.OpenAI.ChatModelId, + TestConfiguration.OpenAI.ApiKey, + serviceId: ServiceKeyGood); + } + else + { + builder.AddAzureOpenAIChatCompletion( + TestConfiguration.AzureOpenAI.ChatDeploymentName, + TestConfiguration.AzureOpenAI.Endpoint, + "bad-key", + serviceId: ServiceKeyBad); + + builder.AddAzureOpenAIChatCompletion( + TestConfiguration.AzureOpenAI.ChatDeploymentName, + TestConfiguration.AzureOpenAI.Endpoint, + TestConfiguration.AzureOpenAI.ApiKey, + serviceId: ServiceKeyGood); + } } return builder.Build(); diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs index 7233d9729984..c2e0c7ccacfd 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs @@ -14,8 +14,10 @@ public class ChatCompletion_Streaming(ITestOutputHelper output) : BaseAgentsTest private const string ParrotName = "Parrot"; private const string ParrotInstructions = "Repeat the user message in the voice of a pirate and then end with a parrot sound."; - [Fact] - public async Task UseStreamingChatCompletionAgentAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseStreamingChatCompletionAgent(bool useChatClient) { // Define the agent ChatCompletionAgent agent = @@ -23,7 +25,7 
@@ public async Task UseStreamingChatCompletionAgentAsync() { Name = ParrotName, Instructions = ParrotInstructions, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; ChatHistoryAgentThread agentThread = new(); @@ -35,10 +37,14 @@ public async Task UseStreamingChatCompletionAgentAsync() // Output the entire chat history await DisplayChatHistory(agentThread); + + chatClient?.Dispose(); } - [Fact] - public async Task UseStreamingChatCompletionAgentWithPluginAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseStreamingChatCompletionAgentWithPlugin(bool useChatClient) { const string MenuInstructions = "Answer questions about the menu."; @@ -48,7 +54,7 @@ public async Task UseStreamingChatCompletionAgentWithPluginAsync() { Name = "Host", Instructions = MenuInstructions, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; @@ -64,6 +70,8 @@ public async Task UseStreamingChatCompletionAgentWithPluginAsync() // Output the entire chat history await DisplayChatHistory(agentThread); + + chatClient?.Dispose(); } // Local function to invoke agent and display the conversation messages. diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_Templating.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_Templating.cs index dcf5e1d34e53..fe7a305dc441 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_Templating.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_Templating.cs @@ -20,14 +20,16 @@ private readonly static (string Input, string? 
Style)[] s_inputs = (Input: "What do you think about having fun?", Style: "old school rap") ]; - [Fact] - public async Task InvokeAgentWithInstructionsTemplateAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task InvokeAgentWithInstructionsTemplate(bool useChatClient) { // Instruction based template always processed by KernelPromptTemplateFactory ChatCompletionAgent agent = new() { - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), Instructions = """ Write a one verse poem on the requested topic in the style of {{$style}}. @@ -40,10 +42,14 @@ Always state the requested style of the poem. }; await InvokeChatCompletionAgentWithTemplateAsync(agent); + + chatClient?.Dispose(); } - [Fact] - public async Task InvokeAgentWithKernelTemplateAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task InvokeAgentWithKernelTemplate(bool useChatClient) { // Default factory is KernelPromptTemplateFactory await InvokeChatCompletionAgentWithTemplateAsync( @@ -52,11 +58,14 @@ Write a one verse poem on the requested topic in the style of {{$style}}. Always state the requested style of the poem. """, PromptTemplateConfig.SemanticKernelTemplateFormat, - new KernelPromptTemplateFactory()); + new KernelPromptTemplateFactory(), + useChatClient); } - [Fact] - public async Task InvokeAgentWithHandlebarsTemplateAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task InvokeAgentWithHandlebarsTemplate(bool useChatClient) { await InvokeChatCompletionAgentWithTemplateAsync( """ @@ -64,11 +73,14 @@ Write a one verse poem on the requested topic in the style of {{style}}. Always state the requested style of the poem. 
""", HandlebarsPromptTemplateFactory.HandlebarsTemplateFormat, - new HandlebarsPromptTemplateFactory()); + new HandlebarsPromptTemplateFactory(), + useChatClient); } - [Fact] - public async Task InvokeAgentWithLiquidTemplateAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task InvokeAgentWithLiquidTemplate(bool useChatClient) { await InvokeChatCompletionAgentWithTemplateAsync( """ @@ -76,13 +88,15 @@ Write a one verse poem on the requested topic in the style of {{style}}. Always state the requested style of the poem. """, LiquidPromptTemplateFactory.LiquidTemplateFormat, - new LiquidPromptTemplateFactory()); + new LiquidPromptTemplateFactory(), + useChatClient); } private async Task InvokeChatCompletionAgentWithTemplateAsync( string instructionTemplate, string templateFormat, - IPromptTemplateFactory templateFactory) + IPromptTemplateFactory templateFactory, + bool useChatClient) { // Define the agent PromptTemplateConfig templateConfig = @@ -94,7 +108,7 @@ private async Task InvokeChatCompletionAgentWithTemplateAsync( ChatCompletionAgent agent = new(templateConfig, templateFactory) { - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), Arguments = new KernelArguments() { {"style", "haiku"} @@ -102,6 +116,8 @@ private async Task InvokeChatCompletionAgentWithTemplateAsync( }; await InvokeChatCompletionAgentWithTemplateAsync(agent); + + chatClient?.Dispose(); } private async Task InvokeChatCompletionAgentWithTemplateAsync(ChatCompletionAgent agent) diff --git a/dotnet/samples/Concepts/Agents/ComplexChat_NestedShopper.cs b/dotnet/samples/Concepts/Agents/ComplexChat_NestedShopper.cs index 6f07fb739190..20327fbc5d50 100644 --- a/dotnet/samples/Concepts/Agents/ComplexChat_NestedShopper.cs +++ b/dotnet/samples/Concepts/Agents/ComplexChat_NestedShopper.cs @@ -92,8 +92,10 @@ Select which participant will take the next turn based on the conversation histo 
{{${{{KernelFunctionTerminationStrategy.DefaultHistoryVariableName}}}}} """; - [Fact] - public async Task NestedChatWithAggregatorAgentAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task NestedChatWithAggregatorAgent(bool useChatClient) { Console.WriteLine($"! {Model}"); @@ -121,7 +123,7 @@ public async Task NestedChatWithAggregatorAgentAsync() new() { TerminationStrategy = - new KernelFunctionTerminationStrategy(outerTerminationFunction, CreateKernelWithChatCompletion()) + new KernelFunctionTerminationStrategy(outerTerminationFunction, CreateKernelWithChatCompletion(useChatClient, out var chatClient)) { ResultParser = (result) => @@ -158,6 +160,8 @@ public async Task NestedChatWithAggregatorAgentAsync() this.WriteAgentChatMessage(message); } + chatClient?.Dispose(); + async Task InvokeChatAsync(string input) { ChatMessageContent message = new(AuthorRole.User, input); diff --git a/dotnet/samples/Concepts/Agents/DeclarativeAgents.cs b/dotnet/samples/Concepts/Agents/DeclarativeAgents.cs index 1f55187359e3..39a99f41b59d 100644 --- a/dotnet/samples/Concepts/Agents/DeclarativeAgents.cs +++ b/dotnet/samples/Concepts/Agents/DeclarativeAgents.cs @@ -6,15 +6,32 @@ namespace Agents; +/// +/// Sample showing how declarative agents can be defined through JSON manifest files. +/// Demonstrates how to load and configure an agent from a declarative manifest that specifies: +/// - The agent's identity (name, description, instructions) +/// - The agent's available actions/plugins +/// - Authentication parameters for accessing external services +/// +/// +/// The test uses a SchedulingAssistant example that can: +/// - Read emails for meeting requests +/// - Check calendar availability +/// - Process scheduling-related tasks +/// The agent is configured via "SchedulingAssistant.json" manifest which defines the required +/// plugins and capabilities. 
+/// public class DeclarativeAgents(ITestOutputHelper output) : BaseAgentsTest(output) { - [InlineData( - "SchedulingAssistant.json", - "Read the body of my last five emails, if any contain a meeting request for today, check that it's already on my calendar, if not, call out which email it is.")] [Theory] - public async Task LoadsAgentFromDeclarativeAgentManifestAsync(string agentFileName, string input) + [InlineData(true)] + [InlineData(false)] + public async Task LoadsAgentFromDeclarativeAgentManifest(bool useChatClient) { - var kernel = this.CreateKernelWithChatCompletion(); + var agentFileName = "SchedulingAssistant.json"; + var input = "Read the body of my last five emails, if any contain a meeting request for today, check that it's already on my calendar, if not, call out which email it is."; + + var kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient); kernel.AutoFunctionInvocationFilters.Add(new ExpectedSchemaFunctionFilter()); var manifestLookupDirectory = Path.Combine(Directory.GetCurrentDirectory(), "..", "..", "..", "Resources", "DeclarativeAgents"); var manifestFilePath = Path.Combine(manifestLookupDirectory, agentFileName); @@ -45,6 +62,8 @@ public async Task LoadsAgentFromDeclarativeAgentManifestAsync(string agentFileNa var responses = await agent.InvokeAsync(new ChatMessageContent(AuthorRole.User, input), agentThread, options: new() { KernelArguments = kernelArguments }).ToArrayAsync(); Assert.NotEmpty(responses); + + chatClient?.Dispose(); } private sealed class ExpectedSchemaFunctionFilter : IAutoFunctionInvocationFilter diff --git a/dotnet/samples/Concepts/Agents/MixedChat_Agents.cs b/dotnet/samples/Concepts/Agents/MixedChat_Agents.cs index 0895308f0215..887f5c95a7f7 100644 --- a/dotnet/samples/Concepts/Agents/MixedChat_Agents.cs +++ b/dotnet/samples/Concepts/Agents/MixedChat_Agents.cs @@ -33,8 +33,10 @@ Only provide a single proposal per response. Consider suggestions when refining an idea. 
"""; - [Fact] - public async Task ChatWithOpenAIAssistantAgentAndChatCompletionAgentAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ChatWithOpenAIAssistantAgentAndChatCompletionAgent(bool useChatClient) { // Define the agents: one of each type ChatCompletionAgent agentReviewer = @@ -42,7 +44,7 @@ public async Task ChatWithOpenAIAssistantAgentAndChatCompletionAgentAsync() { Instructions = ReviewerInstructions, Name = ReviewerName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Define the assistant @@ -87,6 +89,8 @@ await this.AssistantClient.CreateAssistantAsync( } Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]"); + + chatClient?.Dispose(); } private sealed class ApprovalTerminationStrategy : TerminationStrategy diff --git a/dotnet/samples/Concepts/Agents/MixedChat_Files.cs b/dotnet/samples/Concepts/Agents/MixedChat_Files.cs index 56ff0f331f0b..77bf16975c4f 100644 --- a/dotnet/samples/Concepts/Agents/MixedChat_Files.cs +++ b/dotnet/samples/Concepts/Agents/MixedChat_Files.cs @@ -16,8 +16,10 @@ public class MixedChat_Files(ITestOutputHelper output) : BaseAssistantTest(outpu { private const string SummaryInstructions = "Summarize the entire conversation for the user in natural language."; - [Fact] - public async Task AnalyzeFileAndGenerateReportAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task AnalyzeFileAndGenerateReport(bool useChatClient) { await using Stream stream = EmbeddedResource.ReadStream("30-user-context.txt")!; string fileId = await this.Client.UploadAssistantFileAsync(stream, "30-user-context.txt"); @@ -38,7 +40,7 @@ await this.AssistantClient.CreateAssistantAsync( new() { Instructions = SummaryInstructions, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Create a chat for agent interaction. 
@@ -61,6 +63,8 @@ Create a tab delimited file report of the ordered (descending) frequency distrib await this.Client.DeleteFileAsync(fileId); } + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(Agent agent, string? input = null) { diff --git a/dotnet/samples/Concepts/Agents/MixedChat_Images.cs b/dotnet/samples/Concepts/Agents/MixedChat_Images.cs index 158da60e418a..825218a5f8cd 100644 --- a/dotnet/samples/Concepts/Agents/MixedChat_Images.cs +++ b/dotnet/samples/Concepts/Agents/MixedChat_Images.cs @@ -19,8 +19,10 @@ public class MixedChat_Images(ITestOutputHelper output) : BaseAssistantTest(outp private const string SummarizerName = "Summarizer"; private const string SummarizerInstructions = "Summarize the entire conversation for the user in natural language."; - [Fact] - public async Task AnalyzeDataAndGenerateChartAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task AnalyzeDataAndGenerateChartAsync(bool useChatClient) { // Define the assistant Assistant assistant = @@ -39,7 +41,7 @@ await this.AssistantClient.CreateAssistantAsync( { Instructions = SummarizerInstructions, Name = SummarizerName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Create a chat for agent interaction. @@ -73,6 +75,8 @@ await InvokeAgentAsync( await this.AssistantClient.DeleteAssistantAsync(analystAgent.Id); } + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(Agent agent, string? 
input = null) { diff --git a/dotnet/samples/Concepts/Agents/MixedChat_Reset.cs b/dotnet/samples/Concepts/Agents/MixedChat_Reset.cs index 431dcc982a5e..a12dc5087b6c 100644 --- a/dotnet/samples/Concepts/Agents/MixedChat_Reset.cs +++ b/dotnet/samples/Concepts/Agents/MixedChat_Reset.cs @@ -18,8 +18,10 @@ The user may either provide information or query on information previously provi If the query does not correspond with information provided, inform the user that their query cannot be answered. """; - [Fact] - public async Task ResetChatAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ResetChat(bool useChatClient) { // Define the assistant Assistant assistant = @@ -36,7 +38,7 @@ await this.AssistantClient.CreateAssistantAsync( { Name = nameof(ChatCompletionAgent), Instructions = AgentInstructions, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Create a chat for agent interaction. @@ -65,6 +67,8 @@ await this.AssistantClient.CreateAssistantAsync( await this.AssistantClient.DeleteAssistantAsync(assistantAgent.Id); } + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(Agent agent, string? input = null) { diff --git a/dotnet/samples/Concepts/Agents/MixedChat_Serialization.cs b/dotnet/samples/Concepts/Agents/MixedChat_Serialization.cs index 4979ceedacb1..45f220b1a64a 100644 --- a/dotnet/samples/Concepts/Agents/MixedChat_Serialization.cs +++ b/dotnet/samples/Concepts/Agents/MixedChat_Serialization.cs @@ -28,8 +28,10 @@ Never repeat the same number. Only respond with a single number that is the result of your calculation without explanation. 
"""; - [Fact] - public async Task SerializeAndRestoreAgentGroupChatAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task SerializeAndRestoreAgentGroupChat(bool useChatClient) { // Define the agents: one of each type ChatCompletionAgent agentTranslator = @@ -37,7 +39,7 @@ public async Task SerializeAndRestoreAgentGroupChatAsync() { Instructions = TranslatorInstructions, Name = TranslatorName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Define the assistant @@ -74,6 +76,8 @@ await this.AssistantClient.CreateAssistantAsync( this.WriteAgentChatMessage(content); } + chatClient?.Dispose(); + async Task InvokeAgents(AgentGroupChat chat) { await foreach (ChatMessageContent content in chat.InvokeAsync()) diff --git a/dotnet/samples/Concepts/Agents/MixedChat_Streaming.cs b/dotnet/samples/Concepts/Agents/MixedChat_Streaming.cs index fc28c3c683dd..31655862f1ba 100644 --- a/dotnet/samples/Concepts/Agents/MixedChat_Streaming.cs +++ b/dotnet/samples/Concepts/Agents/MixedChat_Streaming.cs @@ -34,8 +34,10 @@ Only provide a single proposal per response. Consider suggestions when refining an idea. 
"""; - [Fact] - public async Task UseStreamingAgentChatAsync() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseStreamingAgentChat(bool useChatClient) { // Define the agents: one of each type ChatCompletionAgent agentReviewer = @@ -43,7 +45,7 @@ public async Task UseStreamingAgentChatAsync() { Instructions = ReviewerInstructions, Name = ReviewerName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Define the assistant @@ -112,6 +114,8 @@ await this.AssistantClient.CreateAssistantAsync( } Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]"); + + chatClient?.Dispose(); } private sealed class ApprovalTerminationStrategy : TerminationStrategy diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index 3e1b7d9b1e38..b2e43b32b9ab 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -8,7 +8,7 @@ false true - $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724,MEVD9000 + $(NoWarn);CS8618,IDE0009,IDE1006,CA1051,CA1050,CA1707,CA1054,CA2007,VSTHRD111,CS1591,RCS1110,RCS1243,CA5394,SKEXP0001,SKEXP0010,SKEXP0020,SKEXP0040,SKEXP0050,SKEXP0060,SKEXP0070,SKEXP0101,SKEXP0110,OPENAI001,CA1724,IDE1006,IDE0009,MEVD9000 Library 5ee045b0-aea3-4f08-8d31-32d1a6f8fed0 diff --git a/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs new file mode 100644 index 000000000000..1e053618a385 --- /dev/null +++ b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; + +namespace Filtering; + +public class ChatClient_AutoFunctionInvocationFiltering(ITestOutputHelper output) : BaseTest(output) +{ + /// + /// Shows how to use . + /// + [Fact] + public async Task UsingAutoFunctionInvocationFilter() + { + var builder = Kernel.CreateBuilder(); + + builder.AddOpenAIChatClient("gpt-4", TestConfiguration.OpenAI.ApiKey); + + // This filter outputs information about auto function invocation and returns overridden result. + builder.Services.AddSingleton(new AutoFunctionInvocationFilter(this.Output)); + + var kernel = builder.Build(); + + var function = KernelFunctionFactory.CreateFromMethod(() => "Result from function", "MyFunction"); + + kernel.ImportPluginFromFunctions("MyPlugin", [function]); + + var executionSettings = new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Required([function], autoInvoke: true) + }; + + var result = await kernel.InvokePromptAsync("Invoke provided function and return result", new(executionSettings)); + + Console.WriteLine(result); + + // Output: + // Request sequence number: 0 + // Function sequence number: 0 + // Total number of functions: 1 + // Result from auto function invocation filter. + } + + /// + /// Shows how to get list of function calls by using . 
+ /// + [Fact] + public async Task GetFunctionCallsWithFilterAsync() + { + var builder = Kernel.CreateBuilder(); + + builder.AddOpenAIChatCompletion("gpt-3.5-turbo-1106", TestConfiguration.OpenAI.ApiKey); + + builder.Services.AddSingleton(new FunctionCallsFilter(this.Output)); + + var kernel = builder.Build(); + + kernel.ImportPluginFromFunctions("HelperFunctions", + [ + kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."), + kernel.CreateFunctionFromMethod((string cityName) => + cityName switch + { + "Boston" => "61 and rainy", + "London" => "55 and cloudy", + "Miami" => "80 and sunny", + "Paris" => "60 and rainy", + "Tokyo" => "50 and sunny", + "Sydney" => "75 and sunny", + "Tel Aviv" => "80 and sunny", + _ => "31 and snowing", + }, "GetWeatherForCity", "Gets the current weather for the specified city"), + ]); + + var executionSettings = new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + await foreach (var chunk in kernel.InvokePromptStreamingAsync("Check current UTC time and return current weather in Boston city.", new(executionSettings))) + { + Console.WriteLine(chunk.ToString()); + } + + // Output: + // Request #0. Function call: HelperFunctions.GetCurrentUtcTime. + // Request #0. Function call: HelperFunctions.GetWeatherForCity. + // The current UTC time is {time of execution}, and the current weather in Boston is 61°F and rainy. + } + + /// Shows available syntax for auto function invocation filter. 
+ private sealed class AutoFunctionInvocationFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + // Example: get function information + var functionName = context.Function.Name; + + // Example: get chat history + var chatHistory = context.ChatHistory; + + // Example: get information about all functions which will be invoked + var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last()); + + // In function calling functionality there are two loops. + // Outer loop is "request" loop - it performs multiple requests to LLM until user ask will be satisfied. + // Inner loop is "function" loop - it handles LLM response with multiple function calls. + + // Workflow example: + // 1. Request to LLM #1 -> Response with 3 functions to call. + // 1.1. Function #1 called. + // 1.2. Function #2 called. + // 1.3. Function #3 called. + // 2. Request to LLM #2 -> Response with 2 functions to call. + // 2.1. Function #1 called. + // 2.2. Function #2 called. + + // context.RequestSequenceIndex - it's a sequence number of outer/request loop operation. + // context.FunctionSequenceIndex - it's a sequence number of inner/function loop operation. + // context.FunctionCount - number of functions which will be called per request (based on example above: 3 for first request, 2 for second request). + + // Example: get request sequence index + output.WriteLine($"Request sequence index: {context.RequestSequenceIndex}"); + + // Example: get function sequence index + output.WriteLine($"Function sequence index: {context.FunctionSequenceIndex}"); + + // Example: get total number of functions which will be called + output.WriteLine($"Total number of functions: {context.FunctionCount}"); + + // Calling next filter in pipeline or function itself. 
+ // By skipping this call, next filters and function won't be invoked, and function call loop will proceed to the next function. + await next(context); + + // Example: get function result + var result = context.Result; + + // Example: override function result value + context.Result = new FunctionResult(context.Result, "Result from auto function invocation filter"); + + // Example: Terminate function invocation + context.Terminate = true; + } + } + + /// Shows how to get list of all function calls per request. + private sealed class FunctionCallsFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + var chatHistory = context.ChatHistory; + var functionCalls = FunctionCallContent.GetFunctionCalls(chatHistory.Last()).ToArray(); + + if (functionCalls is { Length: > 0 }) + { + foreach (var functionCall in functionCalls) + { + output.WriteLine($"Request #{context.RequestSequenceIndex}. Function call: {functionCall.PluginName}.{functionCall.FunctionName}."); + } + } + + await next(context); + } + } +} diff --git a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs index b0fdcad2e86f..02ddbdb3ec35 100644 --- a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs +++ b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; @@ -8,29 +9,39 @@ namespace KernelExamples; +/// +/// This sample shows how to use a custom AI service selector to select a specific model by matching the model id. 
+/// public class CustomAIServiceSelector(ITestOutputHelper output) : BaseTest(output) { - /// - /// Show how to use a custom AI service selector to select a specific model - /// [Fact] - public async Task RunAsync() + public async Task UsingCustomSelectToSelectServiceByMatchingModelId() { - Console.WriteLine($"======== {nameof(CustomAIServiceSelector)} ========"); + Console.WriteLine($"======== {nameof(UsingCustomSelectToSelectServiceByMatchingModelId)} ========"); - // Build a kernel with multiple chat completion services + // Use the custom AI service selector to select any registered service starting with "gpt" on it's model id + var customSelector = new GptAIServiceSelector(modelNameStartsWith: "gpt", this.Output); + + // Build a kernel with multiple chat services var builder = Kernel.CreateBuilder() .AddAzureOpenAIChatCompletion( deploymentName: TestConfiguration.AzureOpenAI.ChatDeploymentName, endpoint: TestConfiguration.AzureOpenAI.Endpoint, apiKey: TestConfiguration.AzureOpenAI.ApiKey, serviceId: "AzureOpenAIChat", - modelId: TestConfiguration.AzureOpenAI.ChatModelId) + modelId: "o1-mini") .AddOpenAIChatCompletion( - modelId: TestConfiguration.OpenAI.ChatModelId, + modelId: "o1-mini", apiKey: TestConfiguration.OpenAI.ApiKey, serviceId: "OpenAIChat"); - builder.Services.AddSingleton(new GptAIServiceSelector(this.Output)); // Use the custom AI service selector to select the GPT model + + // The kernel also allows you to use a IChatClient chat service as well + builder.Services + .AddSingleton(customSelector) + .AddKeyedChatClient("OpenAIChatClient", new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient("gpt-4o") + .AsIChatClient()); // Add a IChatClient to the kernel + Kernel kernel = builder.Build(); // This invocation is done with the model selected by the custom selector @@ -45,20 +56,34 @@ public async Task RunAsync() /// a completion model whose name starts with "gpt". 
But this logic could /// be as elaborate as needed to apply your own selection criteria. /// - private sealed class GptAIServiceSelector(ITestOutputHelper output) : IAIServiceSelector + private sealed class GptAIServiceSelector(string modelNameStartsWith, ITestOutputHelper output) : IAIServiceSelector, IChatClientSelector { private readonly ITestOutputHelper _output = output; + private readonly string _modelNameStartsWith = modelNameStartsWith; - public bool TrySelectAIService( + private bool TrySelect( Kernel kernel, KernelFunction function, KernelArguments arguments, - [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class, IAIService + [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class { foreach (var serviceToCheck in kernel.GetAllServices()) { + string? serviceModelId = null; + string? endpoint = null; + + if (serviceToCheck is IAIService aiService) + { + serviceModelId = aiService.GetModelId(); + endpoint = aiService.GetEndpoint(); + } + else if (serviceToCheck is IChatClient chatClient) + { + var metadata = chatClient.GetService(); + serviceModelId = metadata?.DefaultModelId; + endpoint = metadata?.ProviderUri?.ToString(); + } + // Find the first service that has a model id that starts with "gpt" - var serviceModelId = serviceToCheck.GetModelId(); - var endpoint = serviceToCheck.GetEndpoint(); - if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith("gpt", StringComparison.OrdinalIgnoreCase)) + if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId.StartsWith(this._modelNameStartsWith, StringComparison.OrdinalIgnoreCase)) { this._output.WriteLine($"Selected model: {serviceModelId} {endpoint}"); service = serviceToCheck; @@ -71,5 +96,23 @@ public bool TrySelectAIService( serviceSettings = null; return false; } + + /// + public bool TrySelectAIService( + Kernel kernel, + KernelFunction function, + KernelArguments arguments, + [NotNullWhen(true)] out T? 
service, + out PromptExecutionSettings? serviceSettings) where T : class, IAIService + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); + + /// + public bool TrySelectChatClient( + Kernel kernel, + KernelFunction function, + KernelArguments arguments, + [NotNullWhen(true)] out T? service, + out PromptExecutionSettings? serviceSettings) where T : class, IChatClient + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); } } diff --git a/dotnet/samples/Concepts/Plugins/CustomMutablePlugin.cs b/dotnet/samples/Concepts/Plugins/CustomMutablePlugin.cs index 4cbfcf530b53..4d7ae2eb2cce 100644 --- a/dotnet/samples/Concepts/Plugins/CustomMutablePlugin.cs +++ b/dotnet/samples/Concepts/Plugins/CustomMutablePlugin.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; namespace Plugins; @@ -63,8 +64,8 @@ public override bool TryGetFunction(string name, [NotNullWhen(true)] out KernelF /// Adds a function to the plugin. /// The function to add. /// is null. - /// 's is null. - /// A function with the same already exists in this plugin. + /// 's is null. + /// A function with the same already exists in this plugin. public void AddFunction(KernelFunction function) { ArgumentNullException.ThrowIfNull(function); diff --git a/dotnet/samples/Demos/HomeAutomation/Program.cs b/dotnet/samples/Demos/HomeAutomation/Program.cs index 3b8d1f009c2f..5b6f61cb5c2d 100644 --- a/dotnet/samples/Demos/HomeAutomation/Program.cs +++ b/dotnet/samples/Demos/HomeAutomation/Program.cs @@ -21,7 +21,6 @@ Example that demonstrates how to use Semantic Kernel in conjunction with depende using Microsoft.SemanticKernel.ChatCompletion; // For Azure OpenAI configuration #pragma warning disable IDE0005 // Using directive is unnecessary. 
-using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Connectors.OpenAI; namespace HomeAutomation; diff --git a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/McpServerBuilderExtensions.cs b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/McpServerBuilderExtensions.cs index 9aa7201f4c46..f65dc7e37bc1 100644 --- a/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/McpServerBuilderExtensions.cs +++ b/dotnet/samples/Demos/ModelContextProtocolClientServer/MCPServer/Extensions/McpServerBuilderExtensions.cs @@ -30,7 +30,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, Kernel { foreach (var function in plugin) { - builder.Services.AddSingleton(McpServerTool.Create(function.AsAIFunction(kernel))); + builder.Services.AddSingleton(McpServerTool.Create(function)); } } @@ -41,7 +41,6 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, Kernel builder.Services.AddSingleton>(services => { IEnumerable plugins = services.GetServices(); - Kernel kernel = services.GetRequiredService(); List tools = new(plugins.Count()); @@ -49,7 +48,7 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, Kernel { foreach (var function in plugin) { - tools.Add(McpServerTool.Create(function.AsAIFunction(kernel))); + tools.Add(McpServerTool.Create(function)); } } diff --git a/dotnet/samples/GettingStarted/GettingStarted.csproj b/dotnet/samples/GettingStarted/GettingStarted.csproj index c5c77c4238a2..56acafcd0ca0 100644 --- a/dotnet/samples/GettingStarted/GettingStarted.csproj +++ b/dotnet/samples/GettingStarted/GettingStarted.csproj @@ -35,6 +35,7 @@ all + diff --git a/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj b/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj index aaaaf937f644..a568246d1c85 100644 --- 
a/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj +++ b/dotnet/samples/GettingStartedWithAgents/GettingStartedWithAgents.csproj @@ -17,6 +17,7 @@ + diff --git a/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs b/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs index 2a3d2a71a242..e62b869c0af1 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step01_Agent.cs @@ -22,8 +22,10 @@ public class Step01_Agent(ITestOutputHelper output) : BaseAgentsTest(output) /// Demonstrate the usage of where each invocation is /// a unique interaction with no conversation history between them. /// - [Fact] - public async Task UseSingleChatCompletionAgent() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseSingleChatCompletionAgent(bool useChatClient) { Kernel kernel = this.CreateKernelWithChatCompletion(); @@ -33,7 +35,7 @@ public async Task UseSingleChatCompletionAgent() { Name = ParrotName, Instructions = ParrotInstructions, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Respond to user input @@ -41,6 +43,8 @@ public async Task UseSingleChatCompletionAgent() await InvokeAgentAsync("I came, I saw, I conquered."); await InvokeAgentAsync("Practice makes perfect."); + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. 
async Task InvokeAgentAsync(string input) { @@ -138,8 +142,10 @@ async Task InvokeAgentAsync(string input) } } - [Fact] - public async Task UseTemplateForChatCompletionAgent() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseTemplateForChatCompletionAgent(bool useChatClient) { // Define the agent string generateStoryYaml = EmbeddedResource.Read("GenerateStory.yaml"); @@ -150,7 +156,7 @@ public async Task UseTemplateForChatCompletionAgent() ChatCompletionAgent agent = new(templateConfig, templateFactory) { - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), Arguments = new() { { "topic", "Dog" }, @@ -169,6 +175,8 @@ await InvokeAgentAsync( { "length", "3" }, }); + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(KernelArguments? arguments = null) { diff --git a/dotnet/samples/GettingStartedWithAgents/Step02_Plugins.cs b/dotnet/samples/GettingStartedWithAgents/Step02_Plugins.cs index 20ecf13c22f5..442c22d82c15 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step02_Plugins.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step02_Plugins.cs @@ -13,14 +13,17 @@ namespace GettingStarted; /// public class Step02_Plugins(ITestOutputHelper output) : BaseAgentsTest(output) { - [Fact] - public async Task UseChatCompletionWithPlugin() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseChatCompletionWithPlugin(bool useChatClient) { // Define the agent ChatCompletionAgent agent = CreateAgentWithPlugin( plugin: KernelPluginFactory.CreateFromType(), instructions: "Answer questions about the menu.", - name: "Host"); + name: "Host", + useChatClient: useChatClient); /// Create the chat history thread to capture the agent interaction. 
AgentThread thread = new ChatHistoryAgentThread(); @@ -32,12 +35,15 @@ public async Task UseChatCompletionWithPlugin() await InvokeAgentAsync(agent, thread, "Thank you"); } - [Fact] - public async Task UseChatCompletionWithPluginEnumParameter() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseChatCompletionWithPluginEnumParameter(bool useChatClient) { // Define the agent ChatCompletionAgent agent = CreateAgentWithPlugin( - KernelPluginFactory.CreateFromType()); + KernelPluginFactory.CreateFromType(), + useChatClient: useChatClient); /// Create the chat history thread to capture the agent interaction. AgentThread thread = new ChatHistoryAgentThread(); @@ -46,8 +52,10 @@ public async Task UseChatCompletionWithPluginEnumParameter() await InvokeAgentAsync(agent, thread, "Create a beautiful red colored widget for me."); } - [Fact] - public async Task UseChatCompletionWithTemplateExecutionSettings() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseChatCompletionWithTemplateExecutionSettings(bool useChatClient) { // Read the template resource string autoInvokeYaml = EmbeddedResource.Read("AutoInvokeTools.yaml"); @@ -59,7 +67,7 @@ public async Task UseChatCompletionWithTemplateExecutionSettings() ChatCompletionAgent agent = new(templateConfig, templateFactory) { - Kernel = this.CreateKernelWithChatCompletion() + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; agent.Kernel.Plugins.AddFromType(); @@ -69,19 +77,22 @@ public async Task UseChatCompletionWithTemplateExecutionSettings() // Respond to user input, invoking functions where appropriate. await InvokeAgentAsync(agent, thread, "Create a beautiful red colored widget for me."); + + chatClient?.Dispose(); } private ChatCompletionAgent CreateAgentWithPlugin( KernelPlugin plugin, string? instructions = null, - string? name = null) + string? 
name = null, + bool useChatClient = false) { ChatCompletionAgent agent = new() { Instructions = instructions, Name = name, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out _), Arguments = new KernelArguments(new PromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }), }; diff --git a/dotnet/samples/GettingStartedWithAgents/Step03_Chat.cs b/dotnet/samples/GettingStartedWithAgents/Step03_Chat.cs index 637151780ee2..5f4c84b7ac1d 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step03_Chat.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step03_Chat.cs @@ -33,8 +33,10 @@ Only provide a single proposal per response. Consider suggestions when refining an idea. """; - [Fact] - public async Task UseAgentGroupChatWithTwoAgents() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseAgentGroupChatWithTwoAgents(bool useChatClient) { // Define the agents ChatCompletionAgent agentReviewer = @@ -42,7 +44,7 @@ public async Task UseAgentGroupChatWithTwoAgents() { Instructions = ReviewerInstructions, Name = ReviewerName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient1), }; ChatCompletionAgent agentWriter = @@ -50,7 +52,7 @@ public async Task UseAgentGroupChatWithTwoAgents() { Instructions = CopyWriterInstructions, Name = CopyWriterName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient2), }; // Create a chat for agent interaction. 
@@ -84,6 +86,9 @@ public async Task UseAgentGroupChatWithTwoAgents() } Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]"); + + chatClient1?.Dispose(); + chatClient2?.Dispose(); } private sealed class ApprovalTerminationStrategy : TerminationStrategy diff --git a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs index 97c28b498803..d6c581a4366a 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step04_KernelFunctionStrategies.cs @@ -34,8 +34,10 @@ Never delimit the response with quotation marks. Consider suggestions when refining an idea. """; - [Fact] - public async Task UseKernelFunctionStrategiesWithAgentGroupChat() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseKernelFunctionStrategiesWithAgentGroupChat(bool useChatClient) { // Define the agents ChatCompletionAgent agentReviewer = @@ -43,7 +45,7 @@ public async Task UseKernelFunctionStrategiesWithAgentGroupChat() { Instructions = ReviewerInstructions, Name = ReviewerName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient1), }; ChatCompletionAgent agentWriter = @@ -51,7 +53,7 @@ public async Task UseKernelFunctionStrategiesWithAgentGroupChat() { Instructions = CopyWriterInstructions, Name = CopyWriterName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient2), }; KernelFunction terminationFunction = @@ -139,5 +141,8 @@ No participant should take more than one turn in a row. 
} Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]"); + + chatClient1?.Dispose(); + chatClient2?.Dispose(); } } diff --git a/dotnet/samples/GettingStartedWithAgents/Step05_JsonResult.cs b/dotnet/samples/GettingStartedWithAgents/Step05_JsonResult.cs index 8b7df10b9244..9fad5413bccc 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step05_JsonResult.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step05_JsonResult.cs @@ -27,8 +27,10 @@ Think step-by-step and rate the user input on creativity and expressiveness from } """; - [Fact] - public async Task UseKernelFunctionStrategiesWithJsonResult() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseKernelFunctionStrategiesWithJsonResult(bool useChatClient) { // Define the agents ChatCompletionAgent agent = @@ -36,7 +38,7 @@ public async Task UseKernelFunctionStrategiesWithJsonResult() { Instructions = TutorInstructions, Name = TutorName, - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), }; // Create a chat for agent interaction. @@ -57,6 +59,8 @@ public async Task UseKernelFunctionStrategiesWithJsonResult() await InvokeAgentAsync("The sunset is setting over the mountains."); await InvokeAgentAsync("The sunset is setting over the mountains and filled the sky with a deep red flame, setting the clouds ablaze."); + chatClient?.Dispose(); + // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(string input) { diff --git a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs index ad2476d83f8b..07d8258a653e 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs @@ -1,5 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. 
+using System.ClientModel; +using Azure.AI.OpenAI; using Azure.Identity; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; @@ -26,25 +29,66 @@ Think step-by-step and rate the user input on creativity and expressiveness from } """; - [Fact] - public async Task UseDependencyInjectionToCreateAgent() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task UseDependencyInjectionToCreateAgent(bool useChatClient) { ServiceCollection serviceContainer = new(); serviceContainer.AddLogging(c => c.AddConsole().SetMinimumLevel(LogLevel.Information)); - if (this.UseOpenAIConfig) + if (useChatClient) { - serviceContainer.AddOpenAIChatCompletion( - TestConfiguration.OpenAI.ChatModelId, - TestConfiguration.OpenAI.ApiKey); + IChatClient chatClient; + if (this.UseOpenAIConfig) + { + chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); + } + else if (!string.IsNullOrEmpty(this.ApiKey)) + { + chatClient = new AzureOpenAIClient( + endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), + credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); + } + else + { + chatClient = new AzureOpenAIClient( + endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), + credential: new AzureCliCredential()) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); + } + + var functionCallingChatClient = chatClient!.AsBuilder().UseKernelFunctionInvocation().Build(); + serviceContainer.AddTransient((sp) => functionCallingChatClient); } else { - serviceContainer.AddAzureOpenAIChatCompletion( - TestConfiguration.AzureOpenAI.ChatDeploymentName, - TestConfiguration.AzureOpenAI.Endpoint, - new AzureCliCredential()); + if (this.UseOpenAIConfig) + { + serviceContainer.AddOpenAIChatCompletion( + 
TestConfiguration.OpenAI.ChatModelId, + TestConfiguration.OpenAI.ApiKey); + } + else if (!string.IsNullOrEmpty(this.ApiKey)) + { + serviceContainer.AddAzureOpenAIChatCompletion( + TestConfiguration.AzureOpenAI.ChatDeploymentName, + TestConfiguration.AzureOpenAI.Endpoint, + TestConfiguration.AzureOpenAI.ApiKey); + } + else + { + serviceContainer.AddAzureOpenAIChatCompletion( + TestConfiguration.AzureOpenAI.ChatDeploymentName, + TestConfiguration.AzureOpenAI.Endpoint, + new AzureCliCredential()); + } } // Transient Kernel as each agent may customize its Kernel instance with plug-ins. diff --git a/dotnet/samples/GettingStartedWithAgents/Step07_Telemetry.cs b/dotnet/samples/GettingStartedWithAgents/Step07_Telemetry.cs index b6da6dd6cbb4..0fa19b93cbf1 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step07_Telemetry.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step07_Telemetry.cs @@ -31,10 +31,12 @@ public class Step07_Telemetry(ITestOutputHelper output) : BaseAssistantTest(outp /// Logging is enabled through the and properties. /// This example uses to output logs to the test console, but any compatible logging provider can be used. /// - [Fact] - public async Task Logging() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task Logging(bool useChatClient) { - await RunExampleAsync(loggerFactory: this.LoggerFactory); + await RunExampleAsync(loggerFactory: this.LoggerFactory, useChatClient: useChatClient); // Output: // [AddChatMessages] Adding Messages: 1. @@ -51,18 +53,22 @@ public async Task Logging() /// For output this example uses Console as well as Application Insights. 
/// [Theory] - [InlineData(true, false)] - [InlineData(false, false)] - [InlineData(true, true)] - [InlineData(false, true)] - public async Task Tracing(bool useApplicationInsights, bool useStreaming) + [InlineData(true, false, false)] + [InlineData(false, false, false)] + [InlineData(true, true, false)] + [InlineData(false, true, false)] + [InlineData(true, false, true)] + [InlineData(false, false, true)] + [InlineData(true, true, true)] + [InlineData(false, true, true)] + public async Task Tracing(bool useApplicationInsights, bool useStreaming, bool useChatClient) { using var tracerProvider = GetTracerProvider(useApplicationInsights); using var activity = s_activitySource.StartActivity("MainActivity"); Console.WriteLine($"Operation/Trace ID: {Activity.Current?.TraceId}"); - await RunExampleAsync(useStreaming: useStreaming); + await RunExampleAsync(useStreaming: useStreaming, useChatClient: useChatClient); // Output: // Operation/Trace ID: 132d831ef39c13226cdaa79873f375b8 @@ -82,7 +88,8 @@ public async Task Tracing(bool useApplicationInsights, bool useStreaming) private async Task RunExampleAsync( bool useStreaming = false, - ILoggerFactory? loggerFactory = null) + ILoggerFactory? loggerFactory = null, + bool useChatClient = false) { // Define the agents ChatCompletionAgent agentReviewer = @@ -97,7 +104,7 @@ private async Task RunExampleAsync( If not, provide insight on how to refine suggested copy without examples. """, Description = "An art director who has opinions about copywriting born of a love for David Ogilvy", - Kernel = this.CreateKernelWithChatCompletion(), + Kernel = this.CreateKernelWithChatCompletion(useChatClient, out var chatClient), LoggerFactory = GetLoggerFactoryOrDefault(loggerFactory), }; @@ -190,6 +197,8 @@ Consider suggestions when refining an idea. } Console.WriteLine($"\n[IS COMPLETED: {chat.IsComplete}]"); + + chatClient?.Dispose(); } private TracerProvider? 
GetTracerProvider(bool useApplicationInsights) diff --git a/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj b/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj index 522a1b1571ab..5506e1539d00 100644 --- a/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj +++ b/dotnet/samples/GettingStartedWithProcesses/GettingStartedWithProcesses.csproj @@ -18,6 +18,7 @@ + diff --git a/dotnet/samples/LearnResources/LearnResources.csproj b/dotnet/samples/LearnResources/LearnResources.csproj index 398e4883a6a1..9b3ca218e7f0 100644 --- a/dotnet/samples/LearnResources/LearnResources.csproj +++ b/dotnet/samples/LearnResources/LearnResources.csproj @@ -34,6 +34,7 @@ all + diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 0fbcc3a8a198..8401b230df87 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -13,7 +13,6 @@ using Microsoft.SemanticKernel.Arguments.Extensions; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Diagnostics; -using Microsoft.SemanticKernel.Services; namespace Microsoft.SemanticKernel.Agents; @@ -221,13 +220,44 @@ protected override Task RestoreChannelAsync(string channelState, C internal static (IChatCompletionService service, PromptExecutionSettings? executionSettings) GetChatCompletionService(Kernel kernel, KernelArguments? arguments) { - (IChatCompletionService chatCompletionService, PromptExecutionSettings? executionSettings) = - kernel.ServiceSelector.SelectAIService( - kernel, - arguments?.ExecutionSettings, - arguments ?? []); + // Need to provide a KernelFunction to the service selector as a container for the execution-settings. 
+ KernelFunction nullPrompt = KernelFunctionFactory.CreateFromPrompt("placeholder", arguments?.ExecutionSettings?.Values); + + kernel.ServiceSelector.TrySelectAIService(kernel, nullPrompt, arguments ?? [], out IChatCompletionService? chatCompletionService, out PromptExecutionSettings? executionSettings); + +#pragma warning disable CA2000 // Dispose objects before losing scope + if (chatCompletionService is null + && kernel.ServiceSelector is IChatClientSelector chatClientSelector + && chatClientSelector.TrySelectChatClient(kernel, nullPrompt, arguments ?? [], out var chatClient, out executionSettings) + && chatClient is not null) + { + // This change is temporary until Agents support IChatClient natively in near future. + chatCompletionService = chatClient!.AsChatCompletionService(); + } +#pragma warning restore CA2000 // Dispose objects before losing scope + + if (chatCompletionService is null) + { + var message = new StringBuilder().Append("No service was found for any of the supported types: ").Append(typeof(IChatCompletionService)).Append(", ").Append(typeof(Microsoft.Extensions.AI.IChatClient)).Append('.'); + if (nullPrompt.ExecutionSettings is not null) + { + string serviceIds = string.Join("|", nullPrompt.ExecutionSettings.Keys); + if (!string.IsNullOrEmpty(serviceIds)) + { + message.Append(" Expected serviceIds: ").Append(serviceIds).Append('.'); + } + + string modelIds = string.Join("|", nullPrompt.ExecutionSettings.Values.Select(model => model.ModelId)); + if (!string.IsNullOrEmpty(modelIds)) + { + message.Append(" Expected modelIds: ").Append(modelIds).Append('.'); + } + } + + throw new KernelException(message.ToString()); + } - return (chatCompletionService, executionSettings); + return (chatCompletionService!, executionSettings); } #region private diff --git a/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj b/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj index 4a29c6e5de28..b0d62880152b 100644 --- a/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj +++ 
b/dotnet/src/Agents/OpenAI/Agents.OpenAI.csproj @@ -27,7 +27,7 @@ - + diff --git a/dotnet/src/Agents/UnitTests/AzureAI/Extensions/KernelFunctionExtensionsTests.cs b/dotnet/src/Agents/UnitTests/AzureAI/Extensions/KernelFunctionExtensionsTests.cs index 298d22eae529..857c660193b9 100644 --- a/dotnet/src/Agents/UnitTests/AzureAI/Extensions/KernelFunctionExtensionsTests.cs +++ b/dotnet/src/Agents/UnitTests/AzureAI/Extensions/KernelFunctionExtensionsTests.cs @@ -9,7 +9,7 @@ namespace SemanticKernel.Agents.UnitTests.OpeAzureAInAI.Extensions; /// -/// Unit testing of . +/// Unit testing of . /// public class KernelFunctionExtensionsTests { diff --git a/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs b/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs index 69ac7dce6c6d..ffcc39625205 100644 --- a/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs +++ b/dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; @@ -173,6 +175,43 @@ public async Task VerifyChatCompletionAgentInvocationAsync() Times.Once); } + /// + /// Verify the invocation and response of using . 
+ /// + [Fact] + public async Task VerifyChatClientAgentInvocationAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "what?")])); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = CreateKernel(mockService.Object), + Arguments = [], + }; + + // Act + ChatMessageContent[] result = await agent.InvokeAsync([]).ToArrayAsync(); + + // Assert + Assert.Single(result); + + mockService.Verify( + x => + x.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny()), + Times.Once); + } + /// /// Verify the streaming invocation and response of . /// @@ -218,6 +257,49 @@ public async Task VerifyChatCompletionAgentStreamingAsync() Times.Once); } + /// + /// Verify the streaming invocation and response of using . + /// + [Fact] + public async Task VerifyChatClientAgentStreamingAsync() + { + // Arrange + ChatResponseUpdate[] returnUpdates = + [ + new ChatResponseUpdate(role: ChatRole.Assistant, content: "wh"), + new ChatResponseUpdate(role: null, content: "at?"), + ]; + + Mock mockService = new(); + mockService.Setup( + s => s.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).Returns(returnUpdates.ToAsyncEnumerable()); + + ChatCompletionAgent agent = + new() + { + Instructions = "test instructions", + Kernel = CreateKernel(mockService.Object), + Arguments = [], + }; + + // Act + StreamingChatMessageContent[] result = await agent.InvokeStreamingAsync([]).ToArrayAsync(); + + // Assert + Assert.Equal(2, result.Length); + + mockService.Verify( + x => + x.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny()), + Times.Once); + } + /// /// Verify the invocation and response of . 
/// @@ -244,6 +326,32 @@ public void VerifyChatCompletionServiceSelection() Assert.Throws(() => ChatCompletionAgent.GetChatCompletionService(kernel, new KernelArguments(new PromptExecutionSettings() { ServiceId = "anything" }))); } + /// + /// Verify the invocation and response of using . + /// + [Fact] + public void VerifyChatClientSelection() + { + // Arrange + Mock mockClient = new(); + Kernel kernel = CreateKernel(mockClient.Object); + + // Act + (IChatCompletionService client, PromptExecutionSettings? settings) = ChatCompletionAgent.GetChatCompletionService(kernel, null); + // Assert + Assert.Equal("ChatClientChatCompletionService", client.GetType().Name); + Assert.Null(settings); + + // Act + (client, settings) = ChatCompletionAgent.GetChatCompletionService(kernel, []); + // Assert + Assert.Equal("ChatClientChatCompletionService", client.GetType().Name); + Assert.Null(settings); + + // Act and Assert + Assert.Throws(() => ChatCompletionAgent.GetChatCompletionService(kernel, new KernelArguments(new PromptExecutionSettings() { ServiceId = "anything" }))); + } + /// /// Verify the invocation and response of . /// @@ -270,4 +378,11 @@ private static Kernel CreateKernel(IChatCompletionService chatCompletionService) builder.Services.AddSingleton(chatCompletionService); return builder.Build(); } + + private static Kernel CreateKernel(IChatClient chatClient) + { + var builder = Kernel.CreateBuilder(); + builder.Services.AddSingleton(chatClient); + return builder.Build(); + } } diff --git a/dotnet/src/Agents/UnitTests/OpenAI/Extensions/KernelFunctionExtensionsTests.cs b/dotnet/src/Agents/UnitTests/OpenAI/Extensions/KernelFunctionExtensionsTests.cs index 3710b4841ab3..51cfd5c39808 100644 --- a/dotnet/src/Agents/UnitTests/OpenAI/Extensions/KernelFunctionExtensionsTests.cs +++ b/dotnet/src/Agents/UnitTests/OpenAI/Extensions/KernelFunctionExtensionsTests.cs @@ -9,7 +9,7 @@ namespace SemanticKernel.Agents.UnitTests.OpenAI.Extensions; /// -/// Unit testing of . 
+/// Unit testing of . /// public class KernelFunctionExtensionsTests { diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj index 496a127ebbbd..250dd6b7b94f 100644 --- a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Connectors.Amazon.UnitTests.csproj @@ -9,7 +9,7 @@ enable - $(NoWarn);CS1591;SKEXP0001;SKEXP0070 + $(NoWarn);CS1591;CA2007;VSTHRD111;SKEXP0001;SKEXP0070 diff --git a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Extensions/BedrockServiceCollectionExtensionTests.cs b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Extensions/BedrockServiceCollectionExtensionTests.cs index 022c1fe9eb4a..7a37b7c4e941 100644 --- a/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Extensions/BedrockServiceCollectionExtensionTests.cs +++ b/dotnet/src/Connectors/Connectors.Amazon.UnitTests/Extensions/BedrockServiceCollectionExtensionTests.cs @@ -1,9 +1,19 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; using Amazon.BedrockRuntime; +using Amazon.BedrockRuntime.Model; using Amazon.Runtime; +using Amazon.Runtime.Endpoints; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Amazon.Core; using Microsoft.SemanticKernel.Embeddings; @@ -16,8 +26,18 @@ namespace Microsoft.SemanticKernel.Connectors.Amazon.UnitTests; /// /// Unit tests for the BedrockServiceCollectionExtension class. 
/// -public class BedrockServiceCollectionExtensionTests +public sealed class BedrockServiceCollectionExtensionTests : IDisposable { + private readonly Mock _mockLoggerFactory; + private readonly Mock> _mockLogger; + + public BedrockServiceCollectionExtensionTests() + { + this._mockLoggerFactory = new Mock(); + this._mockLogger = new Mock>(); + this._mockLoggerFactory.Setup(lf => lf.CreateLogger(It.IsAny())).Returns(this._mockLogger.Object); + this._mockLogger.Setup(l => l.IsEnabled(It.IsAny())).Returns(true); + } /// /// Ensures that IServiceCollection.AddBedrockChatCompletionService registers the with the correct implementation. /// @@ -94,4 +114,151 @@ public void AwsServiceClientBeforeServiceRequestDoesNothingForNonWebServiceReque // Assert // No exceptions should be thrown } + + [Fact] + public async Task ChatClientUsesOpenTelemetrySourceNameAsync() + { + // Arrange + string customSourceName = "CustomSourceName"; + bool correctSourceNameUsed = false; + bool configCallbackInvoked = false; + var services = new ServiceCollection(); + var modelId = "amazon.titan-text-v2:0"; + + // Arrange + var mockBedrockApi = new Mock(); + mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny())) + .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com") + { + URL = "https://bedrock-runtime.us-east-1.amazonaws.com" + }); + mockBedrockApi.Setup(m => m.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(this.CreateConverseResponse("Hello, world!", ConversationRole.Assistant)); + var bedrockRuntime = mockBedrockApi.Object; + + // Set up an ActivityListener to capture the activity events + using var activityListener = new ActivityListener + { + ShouldListenTo = activitySource => activitySource.Name == customSourceName, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => correctSourceNameUsed = true + }; + + ActivitySource.AddActivityListener(activityListener); + + var builder = 
Kernel.CreateBuilder(); + builder.Services.AddSingleton(this._mockLoggerFactory.Object); + builder.AddBedrockChatClient( + modelId: modelId, + bedrockRuntime: bedrockRuntime, + openTelemetrySourceName: customSourceName, + openTelemetryConfig: _ => configCallbackInvoked = true); + var kernel = builder.Build(); + + var sut = kernel.GetRequiredService(); + + // Act + var result = await sut.GetResponseAsync([]); + + // Assert + Assert.True(correctSourceNameUsed, "The custom OpenTelemetry source name should have been used"); + Assert.True(configCallbackInvoked, "The OpenTelemetry config callback should have been invoked"); + } + + [Fact] + public async Task EmbeddingGeneratorUsesOpenTelemetrySourceNameAsync() + { + // Arrange + string customSourceName = "CustomSourceName"; + bool correctSourceNameUsed = false; + bool configCallbackInvoked = false; + var services = new ServiceCollection(); + var modelId = "amazon.titan-embed-text-v2:0"; + + // Arrange + var mockBedrockApi = new Mock(); + mockBedrockApi.Setup(m => m.DetermineServiceOperationEndpoint(It.IsAny())) + .Returns(new Endpoint("https://bedrock-runtime.us-east-1.amazonaws.com") + { + URL = "https://bedrock-runtime.us-east-1.amazonaws.com" + }); + mockBedrockApi.Setup(m => m.InvokeModelAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(this.CreateEmbeddingInvokeResponse([0.1f, 0.2f, 0.3f])); + var bedrockRuntime = mockBedrockApi.Object; + + // Set up an ActivityListener to capture the activity events + using var activityListener = new ActivityListener + { + ShouldListenTo = activitySource => activitySource.Name == customSourceName, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllData, + ActivityStarted = activity => correctSourceNameUsed = true + }; + + ActivitySource.AddActivityListener(activityListener); + + var builder = Kernel.CreateBuilder(); + builder.Services.AddSingleton(this._mockLoggerFactory.Object); + builder.AddBedrockEmbeddingGenerator( + modelId: modelId, + bedrockRuntime: 
bedrockRuntime, + openTelemetrySourceName: customSourceName, + openTelemetryConfig: _ => configCallbackInvoked = true); + var kernel = builder.Build(); + + var sut = kernel.GetRequiredService>>(); + + // Act + var result = await sut.GenerateAsync(["test"]); + + // Assert + Assert.True(correctSourceNameUsed, "The custom OpenTelemetry source name should have been used"); + Assert.True(configCallbackInvoked, "The OpenTelemetry config callback should have been invoked"); + } + + public void Dispose() + { + // Disable OpenTelemetry diagnostics after tests + AppContext.SetSwitch("Microsoft.SemanticKernel.Experimental.GenAI.EnableOTelDiagnostics", false); + } + + private ConverseResponse CreateConverseResponse(string text, ConversationRole role) + { + return new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = role, + Content = new List { new() { Text = text } } + } + }, + Metrics = new ConverseMetrics(), + StopReason = StopReason.Max_tokens, + Usage = new TokenUsage() + }; + } + + private InvokeModelResponse CreateEmbeddingInvokeResponse(float[] embedding) + { + var memoryStream = new MemoryStream(System.Text.Json.JsonSerializer.SerializeToUtf8Bytes(new EmbeddingResponse() + { + Embedding = embedding, + InputTextTokenCount = embedding.Length + })); + + return new InvokeModelResponse + { + Body = memoryStream + }; + } + + private sealed class EmbeddingResponse + { + [JsonPropertyName("embedding")] + public float[]? Embedding { get; set; } + + [JsonPropertyName("inputTextTokenCount")] + public int? 
InputTextTokenCount { get; set; } + } } diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs index efdb45bd2138..40cd1e32228c 100644 --- a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockKernelBuilderExtensions.cs @@ -2,17 +2,19 @@ using System; using Amazon.BedrockRuntime; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; /// -/// Extensions for adding Bedrock modality services to the kernel builder configuration. +/// Extensions for adding Bedrock modality services to the configuration. /// public static class BedrockKernelBuilderExtensions { /// - /// Add Amazon Bedrock Chat Completion service to the kernel builder using IAmazonBedrockRuntime object. + /// Add Amazon Bedrock to the using object. /// /// The kernel builder. /// The model for chat completion. @@ -32,8 +34,31 @@ public static IKernelBuilder AddBedrockChatCompletionService( return builder; } + /// Add Amazon Bedrock to the . + /// The service collection. + /// The model for chat completion. + /// The optional to use. If not provided will be retrieved from the Service Collection. + /// The optional service ID. + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// Returns back with a configured . + public static IKernelBuilder AddBedrockChatClient( + this IKernelBuilder builder, + string modelId, + IAmazonBedrockRuntime? bedrockRuntime = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? 
openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddBedrockChatClient(modelId, bedrockRuntime, serviceId, openTelemetrySourceName, openTelemetryConfig); + + return builder; + } + /// - /// Add Amazon Bedrock Text Generation service to the kernel builder using IAmazonBedrockRuntime object. + /// Add Amazon Bedrock Text Generation service to the using object. /// /// The kernel builder. /// The model for text generation. @@ -54,10 +79,10 @@ public static IKernelBuilder AddBedrockTextGenerationService( } /// - /// Add Amazon Bedrock Text Generation service to the kernel builder using IAmazonBedrockRuntime object. + /// Add Amazon Bedrock Text Generation service to the using object. /// /// The kernel builder. - /// The model for text generation. + /// The model for embedding generation. /// The optional to use. If not provided will be retrieved from the Service Collection. /// The optional service ID. /// Returns back with a configured service. @@ -76,22 +101,26 @@ public static IKernelBuilder AddBedrockTextEmbeddingGenerationService( } /// - /// Add Amazon Bedrock Text Generation service to the kernel builder using IAmazonBedrockRuntime object. + /// Add Amazon Bedrock to the using object. /// /// The kernel builder. - /// The model for text generation. + /// The model for embedding generation. /// The optional to use. If not provided will be retrieved from the Service Collection. /// The optional service ID. - /// Returns back with a configured service. + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// Returns back with a configured . public static IKernelBuilder AddBedrockEmbeddingGenerator( this IKernelBuilder builder, string modelId, IAmazonBedrockRuntime? bedrockRuntime = null, - string? serviceId = null) + string? serviceId = null, + string? openTelemetrySourceName = null, + Action>>? 
openTelemetryConfig = null) { Verify.NotNull(builder); - builder.Services.AddBedrockEmbeddingGenerator(modelId, bedrockRuntime, serviceId); + builder.Services.AddBedrockEmbeddingGenerator(modelId, bedrockRuntime, serviceId, openTelemetrySourceName, openTelemetryConfig); return builder; } diff --git a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockServiceCollectionExtensions.DependencyInjection.cs index 4d3ddbf2e2dd..8cd061ce5070 100644 --- a/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.Amazon/Bedrock/Extensions/BedrockServiceCollectionExtensions.DependencyInjection.cs @@ -15,21 +15,89 @@ namespace Microsoft.Extensions.DependencyInjection; /// public static class BedrockServiceCollectionExtensions { + /// Add Amazon Bedrock to the . + /// The service collection. + /// The model for chat completion. + /// The optional to use. If not provided will be retrieved from the Service Collection. + /// The optional service ID. + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// Returns back with a configured service. + public static IServiceCollection AddBedrockChatClient( + this IServiceCollection services, + string modelId, + IAmazonBedrockRuntime? bedrockRuntime = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + if (bedrockRuntime is null) + { + // Add IAmazonBedrockRuntime service client to the DI container + services.TryAddAWSService(); + } + + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + try + { + IAmazonBedrockRuntime runtime = bedrockRuntime ?? 
serviceProvider.GetRequiredService(); + var loggerFactory = serviceProvider.GetService(); + // Check if the runtime instance is a proxy object + if (runtime.GetType().BaseType == typeof(AmazonServiceClient)) + { + // Cast to AmazonServiceClient and subscribe to the event + ((AmazonServiceClient)runtime).BeforeRequestEvent += BedrockClientUtilities.BedrockServiceClientRequestHandler; + } + var builder = runtime + .AsIChatClient(modelId) + .AsBuilder(); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + builder.UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + return builder + .UseKernelFunctionInvocation(loggerFactory) + .Build(serviceProvider); + } + catch (Exception ex) + { + throw new KernelException($"An error occurred while initializing the Bedrock {nameof(IChatClient)}: {ex.Message}", ex); + } + }); + + return services; + } + /// - /// Add Amazon Bedrock Embedding Generator service to the . + /// Add Amazon Bedrock to the . /// /// The service collection. - /// The model for text generation. + /// The model for embedding generation. /// The optional to use. If not provided will be retrieved from the Service Collection. /// The optional service ID. - /// Returns back with a configured service. + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// Returns back with a configured . public static IServiceCollection AddBedrockEmbeddingGenerator( this IServiceCollection services, string modelId, IAmazonBedrockRuntime? bedrockRuntime = null, - string? serviceId = null) + string? serviceId = null, + string? openTelemetrySourceName = null, + Action>>? 
openTelemetryConfig = null) { - if (bedrockRuntime == null) + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + if (bedrockRuntime is null) { // Add IAmazonBedrockRuntime service client to the DI container services.TryAddAWSService(); @@ -47,15 +115,16 @@ public static IServiceCollection AddBedrockEmbeddingGenerator( ((AmazonServiceClient)runtime).BeforeRequestEvent += BedrockClientUtilities.BedrockServiceClientRequestHandler; } - var embeddingGenerator = runtime.AsIEmbeddingGenerator(modelId); + var builder = runtime.AsIEmbeddingGenerator(modelId).AsBuilder(); + if (loggerFactory is not null) { - embeddingGenerator = embeddingGenerator - .AsBuilder() - .UseLogging(loggerFactory) - .Build(); + builder.UseLogging(loggerFactory); } - return embeddingGenerator; + + builder.UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + return builder.Build(serviceProvider); } catch (Exception ex) { diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs index 5f1e784f2c72..9cfeee361e1b 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceKernelBuilderExtensionsTests.cs @@ -3,6 +3,7 @@ using System; using Azure; using Azure.AI.Inference; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; @@ -39,6 +40,33 @@ public void KernelBuilderAddAzureAIInferenceChatCompletionAddsValidService(Initi Assert.Equal("ChatClientChatCompletionService", chatCompletionService.GetType().Name); } + [Theory] + [InlineData(InitializationType.ApiKey)] + 
[InlineData(InitializationType.BreakingGlassClientInline)] + [InlineData(InitializationType.BreakingGlassInServiceProvider)] + public void KernelBuilderAddAzureAIInferenceChatClientAddsValidService(InitializationType type) + { + // Arrange + var client = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key")); + var builder = Kernel.CreateBuilder(); + + builder.Services.AddSingleton(client); + + // Act + builder = type switch + { + InitializationType.ApiKey => builder.AddAzureAIInferenceChatClient("model-id", "api-key", this._endpoint), + InitializationType.BreakingGlassClientInline => builder.AddAzureAIInferenceChatClient("model-id", client), + InitializationType.BreakingGlassInServiceProvider => builder.AddAzureAIInferenceChatClient("model-id", chatClient: null), + _ => builder + }; + + // Assert + var sut = builder.Build().GetRequiredService(); + Assert.NotNull(sut); + Assert.Equal("KernelFunctionInvokingChatClient", sut.GetType().Name); + } + public enum InitializationType { ApiKey, diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs index 6fe1283270bb..53dd4aeb0995 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference.UnitTests/Extensions/AzureAIInferenceServiceCollectionExtensionsTests.cs @@ -49,6 +49,32 @@ public void ItCanAddChatCompletionService(InitializationType type) Assert.Equal("ChatClientChatCompletionService", chatCompletionService.GetType().Name); } + [Theory] + [InlineData(InitializationType.ApiKey)] + [InlineData(InitializationType.ClientInline)] + [InlineData(InitializationType.ClientInServiceProvider)] + public void ItCanAddChatClientService(InitializationType type) + { + // Arrange + 
var client = new ChatCompletionsClient(this._endpoint, new AzureKeyCredential("key")); + var builder = Kernel.CreateBuilder(); + + builder.Services.AddSingleton(client); + + IServiceCollection collection = type switch + { + InitializationType.ApiKey => builder.Services.AddAzureAIInferenceChatClient("modelId", "api-key", this._endpoint), + InitializationType.ClientInline => builder.Services.AddAzureAIInferenceChatClient("modelId", client), + InitializationType.ClientInServiceProvider => builder.Services.AddAzureAIInferenceChatClient("modelId", chatClient: null), + _ => builder.Services + }; + + // Act & Assert + var sut = builder.Build().GetRequiredService(); + Assert.NotNull(sut); + Assert.Equal("KernelFunctionInvokingChatClient", sut.GetType().Name); + } + public enum InitializationType { ApiKey, @@ -84,10 +110,50 @@ public async Task ItAddSemanticKernelHeadersOnEachChatCompletionRequestAsync(Ini _ => builder.Services }; - var chatCompletionService = builder.Build().GetRequiredService(); + var sut = builder.Build().GetRequiredService(); + + // Act + await sut.GetChatMessageContentAsync("test"); + + // Assert + Assert.True(handler.RequestHeaders!.Contains(HttpHeaderConstant.Names.SemanticKernelVersion)); + Assert.Equal(HttpHeaderConstant.Values.GetAssemblyVersion(typeof(ChatClientCore)), handler.RequestHeaders.GetValues(HttpHeaderConstant.Names.SemanticKernelVersion).FirstOrDefault()); + + Assert.True(handler.RequestHeaders.Contains("User-Agent")); + Assert.Contains(HttpHeaderConstant.Values.UserAgent, handler.RequestHeaders.GetValues("User-Agent").FirstOrDefault()); + } + + [Theory] + [InlineData(InitializationType.ApiKey)] + [InlineData(InitializationType.ClientInServiceProvider)] + public async Task ItAddSemanticKernelHeadersOnEachChatClientRequestAsync(InitializationType type) + { + // Arrange + using HttpMessageHandlerStub handler = new(); + using HttpClient httpClient = new(handler); + httpClient.BaseAddress = this._endpoint; + handler.ResponseToReturn = new 
HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent(File.ReadAllText("TestData/chat_completion_response.json")) + }; + + var builder = Kernel.CreateBuilder(); + + IServiceCollection collection = type switch + { + InitializationType.ApiKey => builder.Services.AddAzureAIInferenceChatClient("modelId", "api-key", this._endpoint, httpClient: httpClient), + InitializationType.ClientInServiceProvider => builder.Services.AddAzureAIInferenceChatClient( + modelId: "modelId", + credential: DelegatedTokenCredential.Create((_, _) => new AccessToken("test", DateTimeOffset.Now)), + endpoint: this._endpoint, + httpClient: httpClient), + _ => builder.Services + }; + + var sut = builder.Build().GetRequiredService(); // Act - await chatCompletionService.GetChatMessageContentAsync("test"); + await sut.GetResponseAsync("test"); // Assert Assert.True(handler.RequestHeaders!.Contains(HttpHeaderConstant.Names.SemanticKernelVersion)); @@ -124,10 +190,10 @@ public async Task ItAddSemanticKernelHeadersOnEachEmbeddingGeneratorRequestAsync _ => builder.Services }; - var embeddingGenerator = builder.Build().GetRequiredService>>(); + var sut = builder.Build().GetRequiredService>>(); // Act - await embeddingGenerator.GenerateAsync("test"); + await sut.GenerateAsync("test"); // Assert Assert.True(handler.RequestHeaders!.Contains(HttpHeaderConstant.Names.SemanticKernelVersion)); diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs index 8c7ea7778638..a1982a915760 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceKernelBuilderExtensions.cs @@ -5,6 +5,7 @@ using Azure.AI.Inference; using Azure.Core; using Microsoft.Extensions.AI; +using 
Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel.Connectors.AzureAIInference; namespace Microsoft.SemanticKernel; @@ -96,4 +97,87 @@ public static IKernelBuilder AddAzureAIInferenceChatCompletion( return builder; } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Target Model Id + /// API Key + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IKernelBuilder AddAzureAIInferenceChatClient( + this IKernelBuilder builder, + string modelId, + string? apiKey = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureAIInferenceChatClient(modelId, apiKey, endpoint, httpClient, serviceId, openTelemetrySourceName, openTelemetryConfig); + + return builder; + } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Target Model Id + /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IKernelBuilder AddAzureAIInferenceChatClient( + this IKernelBuilder builder, + string modelId, + TokenCredential credential, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? 
openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureAIInferenceChatClient(modelId, credential, endpoint, httpClient, serviceId, openTelemetrySourceName, openTelemetryConfig); + + return builder; + } + + /// + /// Adds the to the . + /// + /// The instance to augment. + /// Azure AI Inference model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IKernelBuilder AddAzureAIInferenceChatClient( + this IKernelBuilder builder, + string modelId, + ChatCompletionsClient? chatClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureAIInferenceChatClient(modelId, chatClient, serviceId, openTelemetrySourceName, openTelemetryConfig); + + return builder; + } } diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.DependencyInjection.cs index f9b92540bc72..d79bff465256 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.DependencyInjection.cs @@ -17,6 +17,7 @@ namespace Microsoft.Extensions.DependencyInjection; /// public static class AzureAIInferenceServiceCollectionExtensions { + #region EmbeddingGenerator /// /// Add an Azure AI Inference to the . 
/// @@ -85,9 +86,149 @@ public static IServiceCollection AddAzureAIInferenceEmbeddingGenerator( return builder.Build(); }); } + #endregion - #region Private + #region ChatClient + /// + /// Adds an Azure AI Inference to the . + /// + /// The instance to augment. + /// Target Model Id + /// API Key + /// Endpoint / Target URI + /// Custom for HTTP requests. + /// A local identifier for the given AI service + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatClient( + this IServiceCollection services, + string modelId, + string? apiKey = null, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + httpClient ??= serviceProvider.GetService(); + var options = ChatClientCore.GetClientOptions(httpClient); + var loggerFactory = serviceProvider.GetService(); + + var builder = new ChatCompletionsClient(endpoint, new AzureKeyCredential(apiKey ?? SingleSpace), options) + .AsIChatClient(modelId) + .AsBuilder(); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig) + .Build(serviceProvider); + }); + } + + /// + /// Adds an Azure AI Inference to the . + /// + /// The instance to augment. + /// Target Model Id + /// Token credential, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Endpoint / Target URI + /// Custom for HTTP requests. 
+ /// A local identifier for the given AI service + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatClient( + this IServiceCollection services, + string modelId, + TokenCredential credential, + Uri? endpoint = null, + HttpClient? httpClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + httpClient ??= serviceProvider.GetService(); + var options = ChatClientCore.GetClientOptions(httpClient); + + var loggerFactory = serviceProvider.GetService(); + + var builder = new ChatCompletionsClient(endpoint, credential, options) + .AsIChatClient(modelId) + .AsBuilder(); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig) + .Build(serviceProvider); + }); + } + + /// + /// Adds an Azure AI Inference to the . + /// + /// The instance to augment. + /// Azure AI Inference model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IServiceCollection AddAzureAIInferenceChatClient(this IServiceCollection services, + string modelId, + ChatCompletionsClient? chatClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? 
openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + { + chatClient ??= serviceProvider.GetRequiredService(); + + var loggerFactory = serviceProvider.GetService(); + + var builder = chatClient + .AsIChatClient(modelId) + .AsBuilder(); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig) + .Build(serviceProvider); + }); + } + #endregion ChatClient + + #region Private /// /// When using Azure AI Inference against Gateway APIs that don't require an API key, /// this single space is used to avoid breaking the client. diff --git a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs index d789f508ecaa..f1cfecbba3b4 100644 --- a/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureAIInference/Extensions/AzureAIInferenceServiceCollectionExtensions.cs @@ -51,15 +51,14 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( var builder = new Azure.AI.Inference.ChatCompletionsClient(endpoint, new Azure.AzureKeyCredential(apiKey ?? 
SingleSpace), options) .AsIChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig) + .UseKernelFunctionInvocation(loggerFactory); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - builder.UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -98,15 +97,14 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion( var builder = new Azure.AI.Inference.ChatCompletionsClient(endpoint, credential, options) .AsIChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig) + .UseKernelFunctionInvocation(loggerFactory); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - builder.UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } @@ -139,39 +137,20 @@ public static IServiceCollection AddAzureAIInferenceChatCompletion(this IService var builder = chatClient .AsIChatClient(modelId) .AsBuilder() - .UseFunctionInvocation(loggerFactory, f => f.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig) + .UseKernelFunctionInvocation(loggerFactory); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - builder.UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); }); } #region Private - /// - /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current - 
/// asynchronous chain of execution. - /// - /// - /// This is a fail-safe mechanism. If someone accidentally manages to set up execution settings in such a way that - /// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself, - /// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close - /// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution. - /// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in - /// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that - /// prompt function could advertise itself as a candidate for auto-invocation. We don't want to outright block that, - /// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent - /// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made - /// configurable should need arise. - /// - private const int MaxInflightAutoInvokes = 128; - /// /// When using Azure AI Inference against Gateway APIs that don't require an API key, /// this single space is used to avoid breaking the client. diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIKernelBuilderExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIKernelBuilderExtensionsChatClientTests.cs new file mode 100644 index 000000000000..1ef379f4213a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIKernelBuilderExtensionsChatClientTests.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.ClientModel; +using Azure.AI.OpenAI; +using Azure.Core; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; + +namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; + +public class AzureOpenAIKernelBuilderExtensionsChatClientTests +{ + [Fact] + public void AddAzureOpenAIChatClientNullArgsThrow() + { + // Arrange + IKernelBuilder builder = null!; + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act & Assert + var exception = Assert.Throws(() => builder.AddAzureOpenAIChatClient(deploymentName, endpoint, apiKey, serviceId, modelId)); + Assert.Equal("builder", exception.ParamName); + + exception = Assert.Throws(() => builder.AddAzureOpenAIChatClient(deploymentName, new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey)), serviceId, modelId)); + Assert.Equal("builder", exception.ParamName); + + TokenCredential credential = DelegatedTokenCredential.Create((_, _) => new AccessToken(apiKey, DateTimeOffset.Now)); + exception = Assert.Throws(() => builder.AddAzureOpenAIChatClient(deploymentName, endpoint, credential, serviceId, modelId)); + Assert.Equal("builder", exception.ParamName); + } + + [Fact] + public void AddAzureOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + builder.AddAzureOpenAIChatClient(deploymentName, endpoint, apiKey, serviceId, modelId); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void 
AddAzureOpenAIChatClientWithCredentialValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + TokenCredential credential = DelegatedTokenCredential.Create((_, _) => new AccessToken("apiKey", DateTimeOffset.Now)); + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + builder.AddAzureOpenAIChatClient(deploymentName, endpoint, credential, serviceId, modelId); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddAzureOpenAIChatClientWithClientValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + var azureOpenAIClient = new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey)); + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + builder.AddAzureOpenAIChatClient(deploymentName, azureOpenAIClient, serviceId, modelId); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsChatClientTests.cs new file mode 100644 index 000000000000..9bffb50600d4 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Extensions/AzureOpenAIServiceCollectionExtensionsChatClientTests.cs @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.ClientModel; +using Azure.AI.OpenAI; +using Azure.Core; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; + +namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Extensions; + +public class AzureOpenAIServiceCollectionExtensionsChatClientTests +{ + [Fact] + public void AddAzureOpenAIChatClientNullArgsThrow() + { + // Arrange + ServiceCollection services = null!; + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act & Assert + var exception = Assert.Throws(() => services.AddAzureOpenAIChatClient(deploymentName, endpoint, apiKey, serviceId, modelId)); + Assert.Equal("services", exception.ParamName); + + exception = Assert.Throws(() => services.AddAzureOpenAIChatClient(deploymentName, new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey)), serviceId, modelId)); + Assert.Equal("services", exception.ParamName); + + TokenCredential credential = DelegatedTokenCredential.Create((_, _) => new AccessToken(apiKey, DateTimeOffset.Now)); + exception = Assert.Throws(() => services.AddAzureOpenAIChatClient(deploymentName, endpoint, credential, serviceId, modelId)); + Assert.Equal("services", exception.ParamName); + } + + [Fact] + public void AddAzureOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + services.AddAzureOpenAIChatClient(deploymentName, endpoint, apiKey, serviceId, modelId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = 
serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddAzureOpenAIChatClientWithCredentialValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + TokenCredential credential = DelegatedTokenCredential.Create((_, _) => new AccessToken("test key", DateTimeOffset.Now)); + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + services.AddAzureOpenAIChatClient(deploymentName, endpoint, credential, serviceId, modelId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddAzureOpenAIChatClientWithClientValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + var azureOpenAIClient = new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey)); + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + services.AddAzureOpenAIChatClient(deploymentName, azureOpenAIClient, serviceId, modelId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddAzureOpenAIChatClientWorksWithKernel() + { + // Arrange + var services = new ServiceCollection(); + string deploymentName = "gpt-35-turbo"; + string endpoint = "https://test-endpoint.openai.azure.com/"; + string apiKey = "test_api_key"; + string serviceId = "test_service_id"; + string modelId = "gpt-35-turbo"; + + // Act + services.AddAzureOpenAIChatClient(deploymentName, endpoint, apiKey, serviceId, modelId); + 
services.AddKernel(); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var kernel = serviceProvider.GetRequiredService(); + + var serviceFromCollection = serviceProvider.GetKeyedService(serviceId); + var serviceFromKernel = kernel.GetRequiredService(serviceId); + + Assert.NotNull(serviceFromKernel); + Assert.Same(serviceFromCollection, serviceFromKernel); + } +} diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs index ae409be39872..433a5c5897db 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/KernelCore/KernelTests.cs @@ -36,7 +36,7 @@ public async Task FunctionUsageMetricsLoggingHasAllNeededData() this._multiMessageHandlerStub.ResponsesToReturn.Add( new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StringContent(ChatCompletionResponse) } ); - using MeterListener listener = EnableTelemetryMeters(); + using MeterListener listener = new(); var builder = Kernel.CreateBuilder(); builder.Services.AddSingleton(this._mockLoggerFactory.Object); @@ -60,15 +60,26 @@ public async Task FunctionUsageMetricsLoggingHasAllNeededData() public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected() { // Set up a MeterListener to capture the measurements - using MeterListener listener = EnableTelemetryMeters(); + using MeterListener listener = new(); + var isPublished = false; - var measurements = new Dictionary> + var measurements = new Dictionary> { ["semantic_kernel.function.invocation.token_usage.prompt"] = [], ["semantic_kernel.function.invocation.token_usage.completion"] = [], }; - listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => + listener.InstrumentPublished = (instrument, listener) => + { + if (instrument.Name == 
"semantic_kernel.function.invocation.token_usage.prompt" || + instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") + { + isPublished = true; + listener.EnableMeasurementEvents(instrument); + } + }; + + listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => { if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" || instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") @@ -105,6 +116,8 @@ public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected() listener.Dispose(); + Assert.True(isPublished); + while (!completed) { // Wait for the measurements to be completed @@ -118,22 +131,6 @@ public void Dispose() this._multiMessageHandlerStub.Dispose(); } - private static MeterListener EnableTelemetryMeters() - { - var listener = new MeterListener(); - // Enable the listener to collect data for our specific histogram - listener.InstrumentPublished = (instrument, listener) => - { - if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" || - instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") - { - listener.EnableMeasurementEvents(instrument); - } - }; - listener.Start(); - return listener; - } - private const string ChatCompletionResponse = """ { "id": "chatcmpl-8IlRBQU929ym1EqAY2J4T7GGkW5Om", diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/CompatibilitySuppressions.xml b/dotnet/src/Connectors/Connectors.AzureOpenAI/CompatibilitySuppressions.xml new file mode 100644 index 000000000000..223e2f7df367 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/CompatibilitySuppressions.xml @@ -0,0 +1,88 @@ + + + + + CP0002 + 
M:Microsoft.Extensions.DependencyInjection.AzureOpenAIServiceCollectionExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,Azure.AI.OpenAI.AzureOpenAIClient,System.String,System.String,System.Nullable{System.Int32}) + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.AzureOpenAIServiceCollectionExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,Azure.Core.TokenCredential,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.AzureOpenAIServiceCollectionExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,System.String,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.AzureOpenAIKernelBuilderExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.SemanticKernel.IKernelBuilder,System.String,Azure.AI.OpenAI.AzureOpenAIClient,System.String,System.String,System.Nullable{System.Int32}) + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + 
M:Microsoft.SemanticKernel.AzureOpenAIKernelBuilderExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.SemanticKernel.IKernelBuilder,System.String,System.String,Azure.Core.TokenCredential,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.AzureOpenAIKernelBuilderExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.SemanticKernel.IKernelBuilder,System.String,System.String,System.String,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.AzureOpenAIServiceCollectionExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,Azure.AI.OpenAI.AzureOpenAIClient,System.String,System.String,System.Nullable{System.Int32}) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.AzureOpenAIServiceCollectionExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,Azure.Core.TokenCredential,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + 
M:Microsoft.Extensions.DependencyInjection.AzureOpenAIServiceCollectionExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,System.String,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.AzureOpenAIKernelBuilderExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.SemanticKernel.IKernelBuilder,System.String,Azure.AI.OpenAI.AzureOpenAIClient,System.String,System.String,System.Nullable{System.Int32}) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.AzureOpenAIKernelBuilderExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.SemanticKernel.IKernelBuilder,System.String,System.String,Azure.Core.TokenCredential,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.AzureOpenAIKernelBuilderExtensions.AddAzureOpenAIEmbeddingGenerator(Microsoft.SemanticKernel.IKernelBuilder,System.String,System.String,System.String,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.AzureOpenAI.dll + true + + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIKernelBuilderExtensions.cs 
b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIKernelBuilderExtensions.cs index cd66ad829c59..a3b626edfe49 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIKernelBuilderExtensions.cs @@ -6,6 +6,7 @@ using System.Net.Http; using Azure.AI.OpenAI; using Azure.Core; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.AudioToText; @@ -26,6 +27,127 @@ namespace Microsoft.SemanticKernel; /// public static partial class AzureOpenAIKernelBuilderExtensions { + #region Chat Client + + /// + /// Adds an Azure OpenAI to the . + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Optional Azure OpenAI API version, see available here + /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IKernelBuilder AddAzureOpenAIChatClient( + this IKernelBuilder builder, + string deploymentName, + string endpoint, + string apiKey, + string? serviceId = null, + string? modelId = null, + string? apiVersion = null, + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action? 
openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureOpenAIChatClient( + deploymentName, + endpoint, + apiKey, + serviceId, + modelId, + apiVersion, + httpClient, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + /// + /// Adds an Azure OpenAI to the . + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Optional Azure OpenAI API version, see available here + /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IKernelBuilder AddAzureOpenAIChatClient( + this IKernelBuilder builder, + string deploymentName, + string endpoint, + TokenCredential credentials, + string? serviceId = null, + string? modelId = null, + string? apiVersion = null, + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureOpenAIChatClient( + deploymentName, + endpoint, + credentials, + serviceId, + modelId, + apiVersion, + httpClient, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + /// + /// Adds an Azure OpenAI to the . + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// to use for the service. 
If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IKernelBuilder AddAzureOpenAIChatClient( + this IKernelBuilder builder, + string deploymentName, + AzureOpenAIClient? azureOpenAIClient = null, + string? serviceId = null, + string? modelId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddAzureOpenAIChatClient( + deploymentName, + azureOpenAIClient, + serviceId, + modelId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + #endregion + #region Chat Completion /// @@ -314,6 +436,8 @@ public static IKernelBuilder AddAzureOpenAITextToAudio( /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. /// Optional Azure OpenAI API version, see available here /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( @@ -325,7 +449,9 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( string? modelId = null, int? dimensions = null, string? apiVersion = null, - HttpClient? httpClient = null) + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action>>? 
openTelemetryConfig = null) { Verify.NotNull(builder); @@ -337,7 +463,9 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( modelId, dimensions, apiVersion, - httpClient); + httpClient, + openTelemetrySourceName, + openTelemetryConfig); return builder; } @@ -354,6 +482,8 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. /// Optional Azure OpenAI API version, see available here /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( @@ -365,7 +495,9 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( string? modelId = null, int? dimensions = null, string? apiVersion = null, - HttpClient? httpClient = null) + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action>>? openTelemetryConfig = null) { Verify.NotNull(builder); Verify.NotNull(credential); @@ -378,7 +510,9 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( modelId, dimensions, apiVersion, - httpClient); + httpClient, + openTelemetrySourceName, + openTelemetryConfig); return builder; } @@ -392,6 +526,8 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( /// A local identifier for the given AI service /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . 
[Experimental("SKEXP0010")] public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( @@ -400,7 +536,9 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( AzureOpenAIClient? azureOpenAIClient = null, string? serviceId = null, string? modelId = null, - int? dimensions = null) + int? dimensions = null, + string? openTelemetrySourceName = null, + Action>>? openTelemetryConfig = null) { Verify.NotNull(builder); @@ -409,7 +547,9 @@ public static IKernelBuilder AddAzureOpenAIEmbeddingGenerator( azureOpenAIClient, serviceId, modelId, - dimensions); + dimensions, + openTelemetrySourceName, + openTelemetryConfig); return builder; } diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.DependencyInjection.cs index 4eff66bf6f1f..229bef7163a4 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.DependencyInjection.cs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.ClientModel; using System.Diagnostics.CodeAnalysis; using System.Net.Http; using Azure.AI.OpenAI; @@ -7,8 +9,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.Connectors.AzureOpenAI; -using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Http; namespace Microsoft.Extensions.DependencyInjection; @@ -16,10 +16,184 @@ namespace Microsoft.Extensions.DependencyInjection; /// /// Provides extension methods for to configure Azure OpenAI connectors. 
/// -public static class AzureOpenAIServiceCollectionExtensions +public static partial class AzureOpenAIServiceCollectionExtensions { + #region Chat Client + + /// + /// Adds an Azure OpenAI to the . + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Optional Azure OpenAI API version, see available here + /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + [Experimental("SKEXP0010")] + public static IServiceCollection AddAzureOpenAIChatClient( + this IServiceCollection services, + string deploymentName, + string endpoint, + string apiKey, + string? serviceId = null, + string? modelId = null, + string? apiVersion = null, + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNullOrWhiteSpace(apiKey); + + IChatClient Factory(IServiceProvider serviceProvider, object? 
_) + { + var loggerFactory = serviceProvider.GetService(); + + AzureOpenAIClient client = Microsoft.SemanticKernel.AzureOpenAIServiceCollectionExtensions.CreateAzureOpenAIClient( + endpoint, + new ApiKeyCredential(apiKey), + HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + apiVersion); + + var builder = client.GetChatClient(deploymentName) + .AsIChatClient() + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds an Azure OpenAI to the . + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// Optional Azure OpenAI API version, see available here + /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + [Experimental("SKEXP0010")] + public static IServiceCollection AddAzureOpenAIChatClient( + this IServiceCollection services, + string deploymentName, + string endpoint, + TokenCredential credentials, + string? serviceId = null, + string? modelId = null, + string? apiVersion = null, + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action? 
openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(endpoint); + Verify.NotNull(credentials); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + AzureOpenAIClient client = Microsoft.SemanticKernel.AzureOpenAIServiceCollectionExtensions.CreateAzureOpenAIClient( + endpoint, + credentials, + HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + apiVersion); + + var builder = client.GetChatClient(deploymentName) + .AsIChatClient() + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + /// - /// Adds the to the . + /// Adds an Azure OpenAI to the . + /// + /// The instance to augment. + /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + [Experimental("SKEXP0010")] + public static IServiceCollection AddAzureOpenAIChatClient( + this IServiceCollection services, + string deploymentName, + AzureOpenAIClient? azureOpenAIClient = null, + string? serviceId = null, + string? modelId = null, + string? openTelemetrySourceName = null, + Action? 
openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(deploymentName); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var client = azureOpenAIClient ?? serviceProvider.GetRequiredService(); + + var builder = client.GetChatClient(deploymentName) + .AsIChatClient() + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + #endregion + + #region Embedding Generator + /// + /// Adds the to the . /// /// The instance to augment. /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource @@ -30,6 +204,8 @@ public static class AzureOpenAIServiceCollectionExtensions /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. /// Optional Azure OpenAI API version, see available here /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IServiceCollection AddAzureOpenAIEmbeddingGenerator( @@ -41,70 +217,95 @@ public static IServiceCollection AddAzureOpenAIEmbeddingGenerator( string? modelId = null, int? dimensions = null, string? apiVersion = null, - HttpClient? httpClient = null) + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action>>? 
openTelemetryConfig = null) { Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(deploymentName); -#pragma warning disable CS0618 // Type or member is obsolete return services.AddKeyedSingleton>>(serviceId, (serviceProvider, _) => - new AzureOpenAITextEmbeddingGenerationService( - deploymentName, + { + var loggerFactory = serviceProvider.GetService(); + + AzureOpenAIClient client = Microsoft.SemanticKernel.AzureOpenAIServiceCollectionExtensions.CreateAzureOpenAIClient( endpoint, - apiKey, - modelId, + new ApiKeyCredential(apiKey), HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - serviceProvider.GetService(), dimensions, - apiVersion) - .AsEmbeddingGenerator()); -#pragma warning restore CS0618 // Type or member is obsolete + apiVersion); + + var builder = client.GetEmbeddingClient(deploymentName) + .AsIEmbeddingGenerator(dimensions) + .AsBuilder() + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + }); } /// - /// Adds the to the . + /// Adds the to the . /// /// The instance to augment. /// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource /// Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart - /// Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. + /// Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc. /// A local identifier for the given AI service /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. /// Optional Azure OpenAI API version, see available here /// The HttpClient to use with this service. 
+ /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IServiceCollection AddAzureOpenAIEmbeddingGenerator( this IServiceCollection services, string deploymentName, string endpoint, - TokenCredential credential, + TokenCredential credentials, string? serviceId = null, string? modelId = null, int? dimensions = null, string? apiVersion = null, - HttpClient? httpClient = null) + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action>>? openTelemetryConfig = null) { Verify.NotNull(services); - Verify.NotNull(credential); + Verify.NotNull(credentials); -#pragma warning disable CS0618 // Type or member is obsolete return services.AddKeyedSingleton>>(serviceId, (serviceProvider, _) => - new AzureOpenAITextEmbeddingGenerationService( - deploymentName, + { + var loggerFactory = serviceProvider.GetService(); + + AzureOpenAIClient client = Microsoft.SemanticKernel.AzureOpenAIServiceCollectionExtensions.CreateAzureOpenAIClient( endpoint, - credential, - modelId, + credentials, HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - serviceProvider.GetService(), - dimensions, - apiVersion - ) - .AsEmbeddingGenerator()); -#pragma warning restore CS0618 // Type or member is obsolete + apiVersion); + + var builder = client.GetEmbeddingClient(deploymentName) + .AsIEmbeddingGenerator(dimensions) + .AsBuilder() + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + }); } /// - /// Adds the to the . + /// Adds the to the . /// /// The instance to augment. 
/// Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource @@ -112,6 +313,8 @@ public static IServiceCollection AddAzureOpenAIEmbeddingGenerator( /// A local identifier for the given AI service /// Model identifier, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IServiceCollection AddAzureOpenAIEmbeddingGenerator( @@ -120,19 +323,30 @@ public static IServiceCollection AddAzureOpenAIEmbeddingGenerator( AzureOpenAIClient? azureOpenAIClient = null, string? serviceId = null, string? modelId = null, - int? dimensions = null) + int? dimensions = null, + string? openTelemetrySourceName = null, + Action>>? openTelemetryConfig = null) { Verify.NotNull(services); -#pragma warning disable CS0618 // Type or member is obsolete return services.AddKeyedSingleton>>(serviceId, (serviceProvider, _) => - new AzureOpenAITextEmbeddingGenerationService( - deploymentName, - azureOpenAIClient ?? serviceProvider.GetRequiredService(), - modelId, - serviceProvider.GetService(), - dimensions) - .AsEmbeddingGenerator()); -#pragma warning restore CS0618 // Type or member is obsolete + { + var loggerFactory = serviceProvider.GetService(); + var client = azureOpenAIClient ?? 
serviceProvider.GetRequiredService(); + + var builder = client.GetEmbeddingClient(deploymentName) + .AsIEmbeddingGenerator(dimensions) + .AsBuilder() + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + }); } + + #endregion } diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs index e9d6e799f051..c14e03b02f27 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Extensions/AzureOpenAIServiceCollectionExtensions.cs @@ -523,9 +523,9 @@ public static IServiceCollection AddAzureOpenAIAudioToText( #endregion - private static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, ApiKeyCredential credentials, HttpClient? httpClient, string? apiVersion) => + internal static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, ApiKeyCredential credentials, HttpClient? httpClient, string? apiVersion) => new(new Uri(endpoint), credentials, AzureClientCore.GetAzureOpenAIClientOptions(httpClient, apiVersion)); - private static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, TokenCredential credentials, HttpClient? httpClient, string? apiVersion) => + internal static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, TokenCredential credentials, HttpClient? httpClient, string? 
apiVersion) => new(new Uri(endpoint), credentials, AzureClientCore.GetAzureOpenAIClientOptions(httpClient, apiVersion)); } diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI/Services/AzureOpenAITextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI/Services/AzureOpenAITextEmbeddingGenerationService.cs index adf09e506ba0..e5c9f286a915 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI/Services/AzureOpenAITextEmbeddingGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI/Services/AzureOpenAITextEmbeddingGenerationService.cs @@ -18,7 +18,7 @@ namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI; /// Azure OpenAI text embedding service. /// [Experimental("SKEXP0010")] -[Obsolete("Use AzureOpenAIEmbeddingGenerator instead.")] +[Obsolete("Use AddAzureOpenAIEmbeddingGenerator extension methods instead.")] public sealed class AzureOpenAITextEmbeddingGenerationService : ITextEmbeddingGenerationService { private readonly AzureClientCore _client; diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs index edde919e9f67..eaff0f64323f 100644 --- a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -155,16 +155,18 @@ public static IServiceCollection AddOllamaChatCompletion( { var loggerFactory = serviceProvider.GetService(); - var builder = ((IChatClient)new OllamaApiClient(endpoint, modelId)) - .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + var ollamaClient = (IChatClient)new OllamaApiClient(endpoint, modelId); + var builder = ollamaClient.AsBuilder(); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - return 
builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); + return builder + .UseKernelFunctionInvocation(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(); }); } @@ -190,16 +192,18 @@ public static IServiceCollection AddOllamaChatCompletion( var loggerFactory = serviceProvider.GetService(); - var builder = ((IChatClient)new OllamaApiClient(httpClient, modelId)) - .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); + var ollamaClient = (IChatClient)new OllamaApiClient(httpClient, modelId); + var builder = ollamaClient.AsBuilder(); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); + return builder + .UseKernelFunctionInvocation(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(); }); } @@ -230,16 +234,16 @@ public static IServiceCollection AddOllamaChatCompletion( throw new InvalidOperationException($"No {nameof(IOllamaApiClient)} implementations found in the service collection."); } - var builder = ((IChatClient)ollamaClient) - .AsBuilder() - .UseFunctionInvocation(loggerFactory, config => config.MaximumIterationsPerRequest = MaxInflightAutoInvokes); - + var builder = ((IChatClient)ollamaClient).AsBuilder(); if (loggerFactory is not null) { builder.UseLogging(loggerFactory); } - return builder.Build(serviceProvider).AsChatCompletionService(serviceProvider); + return builder + .UseKernelFunctionInvocation(loggerFactory) + .Build(serviceProvider) + .AsChatCompletionService(); }); } @@ -358,26 +362,4 @@ public static IServiceCollection AddOllamaTextEmbeddingGeneration( } #endregion - - #region Private - - /// - /// The maximum number of auto-invokes that can be in-flight at any given time as part of the current - /// asynchronous chain of execution. - /// - /// - /// This is a fail-safe mechanism. 
If someone accidentally manages to set up execution settings in such a way that - /// auto-invocation is invoked recursively, and in particular where a prompt function is able to auto-invoke itself, - /// we could end up in an infinite loop. This const is a backstop against that happening. We should never come close - /// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution. - /// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in - /// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that - /// prompt function could advertize itself as a candidate for auto-invocation. We don't want to outright block that, - /// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent - /// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made - /// configurable should need arise. 
- /// - private const int MaxInflightAutoInvokes = 128; - - #endregion } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj index 57f983dd5c9b..04d35b9e6561 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj @@ -96,6 +96,15 @@ Always + + Always + + + Always + + + Always + diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs new file mode 100644 index 000000000000..ca568ecb229a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -0,0 +1,793 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core; + +public sealed class AutoFunctionInvocationFilterChatClientTests : IDisposable +{ + private readonly MultipleHttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public AutoFunctionInvocationFilterChatClientTests() + { + this._messageHandlerStub = new MultipleHttpMessageHandlerStub(); + + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + + [Fact] + public async Task FiltersAreExecutedCorrectlyAsync() + { + // Arrange + int filterInvocations = 0; + int functionInvocations = 0; + int[] expectedRequestSequenceNumbers = [0, 0, 1, 1]; + int[] expectedFunctionSequenceNumbers = [0, 1, 0, 1]; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + Kernel? 
contextKernel = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + contextKernel = context.Kernel; + + if (context.ChatHistory.Last() is OpenAIChatMessageContent content) + { + Assert.Equal(2, content.ToolCalls.Count); + } + + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + filterInvocations++; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(4, filterInvocations); + Assert.Equal(4, functionInvocations); + Assert.Equal(expectedRequestSequenceNumbers, requestSequenceNumbers); + Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers); + Assert.Same(kernel, contextKernel); + Assert.Equal("Test chat response", result.ToString()); + } + + [Fact] + public async Task FunctionSequenceIndexIsCorrectForConcurrentCallsAsync() + { + // Arrange + List functionSequenceNumbers = []; + List expectedFunctionSequenceNumbers = [0, 1, 0, 1]; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, 
next) => + { + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() + { + AllowParallelCalls = true, + AllowConcurrentInvocation = true + }) + })); + + // Assert + Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers); + } + + [Fact] + public async Task FiltersAreExecutedCorrectlyOnStreamingAsync() + { + // Arrange + int filterInvocations = 0; + int functionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + if (context.ChatHistory.Last() is OpenAIChatMessageContent content) + { + Assert.Equal(2, content.ToolCalls.Count); + } + + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + filterInvocations++; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { } + + // Assert + Assert.Equal(4, filterInvocations); + Assert.Equal(4, functionInvocations); + Assert.Equal([0, 0, 1, 1], 
requestSequenceNumbers); + Assert.Equal([0, 1, 0, 1], functionSequenceNumbers); + } + + [Fact] + public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync() + { + // Arrange + var executionOrder = new List(); + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter1 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter1-Invoking"); + await next(context); + }); + + var filter2 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter2-Invoking"); + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + + // Case #1 - Add filter to services + builder.Services.AddSingleton(filter1); + + var kernel = builder.Build(); + + // Case #2 - Add filter to kernel + kernel.AutoFunctionInvocationFilters.Add(filter2); + + var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal("Filter1-Invoking", executionOrder[0]); + Assert.Equal("Filter2-Invoking", executionOrder[1]); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming) + { + // Arrange + var executionOrder = new List(); + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); 
+ + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter1 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter1-Invoking"); + await next(context); + executionOrder.Add("Filter1-Invoked"); + }); + + var filter2 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter2-Invoking"); + await next(context); + executionOrder.Add("Filter2-Invoked"); + }); + + var filter3 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter3-Invoking"); + await next(context); + executionOrder.Add("Filter3-Invoked"); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + builder.Services.AddSingleton(filter1); + builder.Services.AddSingleton(filter2); + builder.Services.AddSingleton(filter3); + + var kernel = builder.Build(); + + var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(settings))) + { } + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", new(settings)); + } + + // Assert + Assert.Equal("Filter1-Invoking", executionOrder[0]); + Assert.Equal("Filter2-Invoking", executionOrder[1]); + Assert.Equal("Filter3-Invoking", executionOrder[2]); + Assert.Equal("Filter3-Invoked", executionOrder[3]); + Assert.Equal("Filter2-Invoked", executionOrder[4]); + Assert.Equal("Filter1-Invoked", executionOrder[5]); + } + + [Fact] + public async Task FilterCanOverrideArgumentsAsync() + { + // Arrange + const string NewValue = "NewValue"; + + var 
function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + context.Arguments!["parameter"] = NewValue; + await next(context); + context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + var chatResponse = Assert.IsType(result.GetValue()); + Assert.NotNull(chatResponse); + + var lastFunctionResult = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(lastFunctionResult); + Assert.Equal("NewValue", lastFunctionResult.ToString()); + } + + [Fact] + public async Task FilterCanHandleExceptionAsync() + { + // Arrange + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2"); + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + try + { + await next(context); + } + catch (KernelException exception) + { + Assert.Equal("Exception from Function1", exception.Message); + context.Result = new FunctionResult(context.Result, "Result from filter"); + } + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + var chatClient = kernel.GetRequiredService(); + + var executionSettings = new PromptExecutionSettings { 
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var options = executionSettings.ToChatOptions(kernel); + List messageList = [new(ChatRole.System, "System message")]; + + // Act + var resultMessages = await chatClient.GetResponseAsync(messageList, options, CancellationToken.None); + + // Assert + var firstToolMessage = resultMessages.Messages.First(m => m.Role == ChatRole.Tool); + Assert.NotNull(firstToolMessage); + var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; + var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; + + Assert.NotNull(firstFunctionResult); + Assert.NotNull(secondFunctionResult); + Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); + Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString()); + } + + [Fact] + public async Task FilterCanHandleExceptionOnStreamingAsync() + { + // Arrange + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2"); + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + try + { + await next(context); + } + catch (KernelException) + { + context.Result = new FunctionResult(context.Result, "Result from filter"); + } + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var chatClient = kernel.GetRequiredService(); + + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var options = executionSettings.ToChatOptions(kernel); + List messageList = []; + + // Act + List streamingContent = []; + await foreach (var update in 
chatClient.GetStreamingResponseAsync(messageList, options, CancellationToken.None)) + { + streamingContent.Add(update); + } + var chatResponse = streamingContent.ToChatResponse(); + + // Assert + var firstToolMessage = chatResponse.Messages.First(m => m.Role == ChatRole.Tool); + Assert.NotNull(firstToolMessage); + var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; + var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; + + Assert.NotNull(firstFunctionResult); + Assert.NotNull(secondFunctionResult); + Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); + Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString()); + } + + [Fact] + public async Task FiltersCanSkipFunctionExecutionAsync() + { + // Arrange + int filterInvocations = 0; + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Filter delegate is invoked only for second function, the first one should be skipped. 
+ if (context.Function.Name == "Function2" && context.Function.PluginName == "MyPlugin") + { + await next(context); + } + + filterInvocations++; + }); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) }; + + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(2, filterInvocations); + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(1, secondFunctionInvocations); + } + + [Fact] + public async Task PreFilterCanTerminateOperationAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Terminating before first function, so all functions won't be invoked. 
+ context.Terminate = true; + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + } + + [Fact] + public async Task PreFilterCanTerminateOperationOnStreamingAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Terminating before first function, so all functions won't be invoked. 
+ context.Terminate = true; + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { } + + // Assert + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + } + + [Fact] + public async Task PostFilterCanTerminateOperationAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + // Terminating after first function, so second function won't be invoked. 
+ context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var functionResult = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(1, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + Assert.Equal([0], requestSequenceNumbers); + Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + var chatResponse = functionResult.GetValue(); + Assert.NotNull(chatResponse); + + var result = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(result); + Assert.Equal("function1-value", result.ToString()); + } + + [Fact] + public async Task PostFilterCanTerminateOperationOnStreamingAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + // Terminating after first function, so second function won't be invoked. 
+ context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + List streamingContent = []; + + // Act + await foreach (var update in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { + streamingContent.Add(update); + } + + // Assert + Assert.Equal(1, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + Assert.Equal([0], requestSequenceNumbers); + Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + Assert.Equal(4, streamingContent.Count); + + var chatResponse = streamingContent.ToChatResponse(); + Assert.NotNull(chatResponse); + + var result = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(result); + Assert.Equal("function1-value", result.ToString()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? 
actualStreamingFlag = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter = new AutoFunctionInvocationFilter(async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + builder.Services.AddSingleton(filter); + + var kernel = builder.Build(); + + var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await kernel.InvokePromptStreamingAsync("Test prompt", new(settings)).ToListAsync(); + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", new(settings)); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } + + [Fact] + public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync() + { + // Arrange + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]); + + AutoFunctionInvocationContext? 
actualContext = null; + + var kernel = this.GetKernelWithFilter(plugin, (context, next) => + { + actualContext = context; + return Task.CompletedTask; + }); + + var expectedExecutionSettings = new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings)); + + // Assert + Assert.NotNull(actualContext); + Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings); + } + + [Fact] + public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync() + { + // Arrange + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]); + + AutoFunctionInvocationContext? actualContext = null; + + var kernel = this.GetKernelWithFilter(plugin, (context, next) => + { + actualContext = context; + return Task.CompletedTask; + }); + + var expectedExecutionSettings = new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings))) + { } + + // Assert + Assert.NotNull(actualContext); + Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } + + #region private + + private static object? 
GetLastFunctionResultFromChatResponse(ChatResponse chatResponse) + { + Assert.NotEmpty(chatResponse.Messages); + var chatMessage = chatResponse.Messages.Where(m => m.Role == ChatRole.Tool).Last(); + + Assert.NotEmpty(chatMessage.Contents); + Assert.Contains(chatMessage.Contents, c => c is Microsoft.Extensions.AI.FunctionResultContent); + + var resultContent = (Microsoft.Extensions.AI.FunctionResultContent)chatMessage.Contents.Last(c => c is Microsoft.Extensions.AI.FunctionResultContent); + return resultContent.Result; + } + +#pragma warning disable CA2000 // Dispose objects before losing scope + private static List GetFunctionCallingResponses() + { + return [ + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) } + ]; + } + + private static List GetFunctionCallingStreamingResponses() + { + return [ + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) } + ]; + } +#pragma warning restore CA2000 + + private Kernel GetKernelWithFilter( + KernelPlugin plugin, + Func, Task>? 
onAutoFunctionInvocation) + { + var builder = Kernel.CreateBuilder(); + var filter = new AutoFunctionInvocationFilter(onAutoFunctionInvocation); + + builder.Plugins.Add(plugin); + builder.Services.AddSingleton(filter); + + builder.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + return builder.Build(); + } + + private sealed class AutoFunctionInvocationFilter( + Func, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter + { + private readonly Func, Task>? _onAutoFunctionInvocation = onAutoFunctionInvocation; + + public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) => + this._onAutoFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs index 19992be01667..b308206b12d5 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs @@ -312,7 +312,7 @@ public async Task FilterCanOverrideArgumentsAsync() // Act var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -596,7 +596,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() Assert.Equal([0], requestSequenceNumbers); Assert.Equal([0], functionSequenceNumbers); - // Results of function invoked before termination should be returned + // Results of function invoked before termination should be returned Assert.Equal(3, streamingContent.Count); var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; diff --git 
a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs new file mode 100644 index 000000000000..437d347aa194 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + +public class OpenAIKernelBuilderExtensionsChatClientTests +{ + [Fact] + public void AddOpenAIChatClientNullArgsThrow() + { + // Arrange + IKernelBuilder builder = null!; + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act & Assert + var exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId)); + Assert.Equal("builder", exception.ParamName); + + exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId)); + Assert.Equal("builder", exception.ParamName); + + using var httpClient = new HttpClient(); + exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient)); + Assert.Equal("builder", exception.ParamName); + } + + [Fact] + public void AddOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + + // Assert + var kernel = builder.Build(); + 
Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + var openAIClient = new OpenAIClient("test_api_key"); + string serviceId = "test_service_id"; + + // Act + builder.AddOpenAIChatClient(modelId, openAIClient, serviceId); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + using var httpClient = new HttpClient(); + + // Act + builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs new file mode 100644 index 000000000000..7a3888b95f30 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + +public class OpenAIServiceCollectionExtensionsChatClientTests +{ + [Fact] + public void AddOpenAIChatClientNullArgsThrow() + { + // Arrange + ServiceCollection services = null!; + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act & Assert + var exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId)); + Assert.Equal("services", exception.ParamName); + + exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId)); + Assert.Equal("services", exception.ParamName); + + using var httpClient = new HttpClient(); + exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient)); + Assert.Equal("services", exception.ParamName); + } + + [Fact] + public void AddOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + var openAIClient = new OpenAIClient("test_api_key"); + string serviceId = "test_service_id"; + + // Act + 
services.AddOpenAIChatClient(modelId, openAIClient, serviceId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + using var httpClient = new HttpClient(); + // Act + services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient); + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientWorksWithKernel() + { + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + services.AddKernel(); + + var serviceProvider = services.BuildServiceProvider(); + var kernel = serviceProvider.GetRequiredService(); + + var serviceFromCollection = serviceProvider.GetKeyedService(serviceId); + var serviceFromKernel = kernel.GetRequiredService(serviceId); + + Assert.NotNull(serviceFromKernel); + Assert.Same(serviceFromCollection, serviceFromKernel); + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs index 4909504d743e..621f5124c8ab 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/KernelCore/KernelTests.cs @@ -61,15 +61,26 @@ public async Task 
FunctionUsageMetricsLoggingHasAllNeededData() public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected() { // Set up a MeterListener to capture the measurements - using MeterListener listener = EnableTelemetryMeters(); + using MeterListener listener = new(); + var isPublished = false; - var measurements = new Dictionary> + var measurements = new Dictionary> { ["semantic_kernel.function.invocation.token_usage.prompt"] = [], ["semantic_kernel.function.invocation.token_usage.completion"] = [], }; - listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => + listener.InstrumentPublished = (instrument, listener) => + { + if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" || + instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") + { + isPublished = true; + listener.EnableMeasurementEvents(instrument); + } + }; + + listener.SetMeasurementEventCallback((instrument, measurement, tags, state) => { if (instrument.Name == "semantic_kernel.function.invocation.token_usage.prompt" || instrument.Name == "semantic_kernel.function.invocation.token_usage.completion") @@ -106,6 +117,8 @@ public async Task FunctionUsageMetricsAreCapturedByTelemetryAsExpected() listener.Dispose(); + Assert.True(isPublished); + while (!completed) { // Wait for the measurements to be completed diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt new file mode 100644 index 000000000000..17ce94647fd5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt @@ -0,0 +1,9 @@ +data: 
{"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_GetCurrentWeather","arguments":"{\n\"location\": \"Boston, MA\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_FunctionWithException","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":2,"id":"3","type":"function","function":{"name":"MyPlugin_NonExistentFunction","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":3,"id":"4","type":"function","function":{"name":"MyPlugin_InvalidArguments","arguments":"invalid_arguments_format"}}]},"finish_reason":"tool_calls"}]} + +data: [DONE] diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json new file mode 100644 index 000000000000..2c499b14089f --- /dev/null +++ 
b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json @@ -0,0 +1,40 @@ +{ + "id": "response-id", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "1", + "type": "function", + "function": { + "name": "MyPlugin_Function1", + "arguments": "{\n\"parameter\": \"function1-value\"\n}" + } + }, + { + "id": "2", + "type": "function", + "function": { + "name": "MyPlugin_Function2", + "arguments": "{\n\"parameter\": \"function2-value\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt new file mode 100644 index 000000000000..c113e3fa97ca --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt @@ -0,0 +1,5 @@ +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming 
response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: [DONE] diff --git a/dotnet/src/Connectors/Connectors.OpenAI/CompatibilitySuppressions.xml b/dotnet/src/Connectors/Connectors.OpenAI/CompatibilitySuppressions.xml new file mode 100644 index 000000000000..f37e1825f8d4 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/CompatibilitySuppressions.xml @@ -0,0 +1,32 @@ + + + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.OpenAIServiceCollectionExtensions.AddOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,OpenAI.OpenAIClient,System.Nullable{System.Int32},System.String) + lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + true + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.OpenAIServiceCollectionExtensions.AddOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + true + + + CP0002 + M:Microsoft.Extensions.DependencyInjection.OpenAIServiceCollectionExtensions.AddOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,OpenAI.OpenAIClient,System.Nullable{System.Int32},System.String) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + true + + + CP0002 + 
M:Microsoft.Extensions.DependencyInjection.OpenAIServiceCollectionExtensions.AddOpenAIEmbeddingGenerator(Microsoft.Extensions.DependencyInjection.IServiceCollection,System.String,System.String,System.String,System.Nullable{System.Int32},System.String,System.Net.Http.HttpClient) + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll + true + + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj index ea6cd02fa983..2f280b843e10 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj @@ -37,6 +37,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs index 7e166e6679e8..08a1f486123f 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs @@ -846,7 +846,7 @@ private static List CreateRequestMessages(ChatMessageContent messag // HTTP 400 (invalid_request_error:) [] should be non-empty - 'messages.3.tool_calls' if (toolCalls.Count == 0) { - return [new AssistantChatMessage(message.Content) { ParticipantName = message.AuthorName }]; + return [new AssistantChatMessage(message.Content ?? 
string.Empty) { ParticipantName = message.AuthorName }]; } var assistantMessage = new AssistantChatMessage(SanitizeFunctionNames(toolCalls)) { ParticipantName = message.AuthorName }; diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs index 146701cdfaf0..d5533a217f5b 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs @@ -190,17 +190,27 @@ internal void AddAttribute(string key, string? value) /// Gets options to use for an OpenAIClient /// Custom for HTTP requests. /// Endpoint for the OpenAI API. + /// /// An instance of . - private static OpenAIClientOptions GetOpenAIClientOptions(HttpClient? httpClient, Uri? endpoint) + internal static OpenAIClientOptions GetOpenAIClientOptions(HttpClient? httpClient, Uri? endpoint = null, string? orgId = null) { OpenAIClientOptions options = new() { UserAgentApplicationId = HttpHeaderConstant.Values.UserAgent, - Endpoint = endpoint }; + if (endpoint is not null) + { + options.Endpoint = endpoint; + } + options.AddPolicy(CreateRequestHeaderPolicy(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(ClientCore))), PipelinePosition.PerCall); + if (orgId is not null) + { + options.OrganizationId = orgId; + } + if (httpClient is not null) { options.Transport = new HttpClientPipelineTransport(httpClient); diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs new file mode 100644 index 000000000000..9b038f10e6f8 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using OpenAI; + +namespace Microsoft.SemanticKernel; + +/// Extension methods for . +[Experimental("SKEXP0010")] +public static class OpenAIChatClientKernelBuilderExtensions +{ + #region Chat Completion + + /// + /// Adds an OpenAI to the . + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + string apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + apiKey, + orgId, + serviceId, + httpClient); + + return builder; + } + + /// + /// Adds an OpenAI to the . + /// + /// The instance to augment. + /// OpenAI model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + OpenAIClient? openAIClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + openAIClient, + serviceId); + + return builder; + } + + /// + /// Adds a custom endpoint OpenAI to the . + /// + /// The instance to augment. 
+ /// OpenAI model name, see https://platform.openai.com/docs/models + /// Custom OpenAI Compatible Message API endpoint + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + Uri endpoint, + string? apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + endpoint, + apiKey, + orgId, + serviceId, + httpClient); + + return builder; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs index 40e08af474b7..1d6f79766974 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs @@ -21,7 +21,7 @@ namespace Microsoft.SemanticKernel; /// -/// Sponsor extensions class for . +/// Extension methods for . /// public static class OpenAIKernelBuilderExtensions { @@ -337,7 +337,7 @@ public static IKernelBuilder AddOpenAIFiles( #region Chat Completion /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an to the . /// /// The instance to augment. /// OpenAI model name, see https://platform.openai.com/docs/models @@ -372,7 +372,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _) } /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an to the . /// /// The instance to augment. 
/// OpenAI model id @@ -398,7 +398,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _) } /// - /// Adds the Custom Endpoint OpenAI chat completion service to the list. + /// Adds a custom endpoint to the . /// /// The instance to augment. /// OpenAI model name, see https://platform.openai.com/docs/models diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.DependencyInjection.cs index 78f404b650fc..2c1b865de37d 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.DependencyInjection.cs @@ -1,12 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.ClientModel; +using System.ClientModel.Primitives; using System.Diagnostics.CodeAnalysis; using System.Net.Http; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; -using Microsoft.SemanticKernel.Embeddings; using Microsoft.SemanticKernel.Http; using OpenAI; @@ -15,9 +17,170 @@ namespace Microsoft.Extensions.DependencyInjection; /// /// Sponsor extensions class for . /// -public static partial class OpenAIServiceCollectionExtensions +public static class OpenAIServiceCollectionExtensions { - #region Text Embedding + #region Chat Client + + /// + /// White space constant. + /// + private const string SingleSpace = " "; + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. 
This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient( + this IServiceCollection services, + string modelId, + string apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + ClientCore.GetOpenAIClientOptions( + endpoint: null, + orgId: orgId, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider)); + + var builder = new OpenAIClient( + credential: new ApiKeyCredential(apiKey ?? SingleSpace), + options: ClientCore.GetOpenAIClientOptions(HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) + .GetChatClient(modelId) + .AsIChatClient() + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . 
+ public static IServiceCollection AddOpenAIChatClient(this IServiceCollection services, + string modelId, + OpenAIClient? openAIClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var builder = (openAIClient ?? serviceProvider.GetRequiredService()) + .GetChatClient(modelId) + .AsIChatClient() + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the Custom OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// A Custom Message API compatible endpoint. + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient( + this IServiceCollection services, + string modelId, + Uri endpoint, + string? apiKey = null, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? 
_) + { + var loggerFactory = serviceProvider.GetService(); + + var builder = new OpenAIClient( + credential: new ApiKeyCredential(apiKey ?? SingleSpace), + options: ClientCore.GetOpenAIClientOptions( + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + endpoint: endpoint, + orgId: orgId)) + .GetChatClient(modelId) + .AsIChatClient() + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + #endregion + + #region Embedding Generator /// /// Adds the to the . /// @@ -28,6 +191,8 @@ public static partial class OpenAIServiceCollectionExtensions /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. /// A local identifier for the given AI service /// The HttpClient to use with this service. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IServiceCollection AddOpenAIEmbeddingGenerator( @@ -37,22 +202,35 @@ public static IServiceCollection AddOpenAIEmbeddingGenerator( string? orgId = null, int? dimensions = null, string? serviceId = null, - HttpClient? httpClient = null) + HttpClient? httpClient = null, + string? openTelemetrySourceName = null, + Action>>? 
openTelemetryConfig = null) { Verify.NotNull(services); Verify.NotNullOrWhiteSpace(modelId); Verify.NotNullOrWhiteSpace(apiKey); return services.AddKeyedSingleton>>(serviceId, (serviceProvider, _) => -#pragma warning disable CS0618 // Type or member is obsolete - new OpenAITextEmbeddingGenerationService( - modelId, - apiKey, - orgId, - HttpClientProvider.GetHttpClient(httpClient, serviceProvider), - serviceProvider.GetService(), dimensions) - .AsEmbeddingGenerator()); -#pragma warning restore CS0618 // Type or member is obsolete + { + var loggerFactory = serviceProvider.GetService(); + + var builder = new OpenAIClient( + credential: new ApiKeyCredential(apiKey ?? SingleSpace), + options: ClientCore.GetOpenAIClientOptions( + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + orgId: orgId)) + .GetEmbeddingClient(modelId) + .AsIEmbeddingGenerator(dimensions) + .AsBuilder() + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + }); } /// @@ -63,26 +241,63 @@ public static IServiceCollection AddOpenAIEmbeddingGenerator( /// to use for the service. If null, one must be available in the service provider when this service is resolved. /// The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models. /// A local identifier for the given AI service + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. /// The same instance as . [Experimental("SKEXP0010")] public static IServiceCollection AddOpenAIEmbeddingGenerator(this IServiceCollection services, string modelId, OpenAIClient? openAIClient = null, int? dimensions = null, - string? serviceId = null) + string? serviceId = null, + string? openTelemetrySourceName = null, + Action>>? 
openTelemetryConfig = null) { Verify.NotNull(services); Verify.NotNullOrWhiteSpace(modelId); -#pragma warning disable CS0618 // Type or member is obsolete return services.AddKeyedSingleton>>(serviceId, (serviceProvider, _) => - new OpenAITextEmbeddingGenerationService( - modelId, - openAIClient ?? serviceProvider.GetRequiredService(), - serviceProvider.GetService(), - dimensions) - .AsEmbeddingGenerator()); -#pragma warning restore CS0618 // Type or member is obsolete + { + var loggerFactory = serviceProvider.GetService(); + + var builder = (openAIClient ?? serviceProvider.GetRequiredService()) + .GetEmbeddingClient(modelId) + .AsIEmbeddingGenerator(dimensions) + .AsBuilder() + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + }); } #endregion + + private static OpenAIClientOptions GetClientOptions( + Uri? endpoint = null, + string? orgId = null, + HttpClient? httpClient = null) + { + OpenAIClientOptions options = new(); + + if (endpoint is not null) + { + options.Endpoint = endpoint; + } + + if (orgId is not null) + { + options.OrganizationId = orgId; + } + + if (httpClient is not null) + { + options.Transport = new HttpClientPipelineTransport(httpClient); + } + + return options; + } } diff --git a/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs b/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs index 8d3c810383e8..24932e1e72e6 100644 --- a/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs +++ b/dotnet/src/Experimental/Process.IntegrationTests.Shared/ProcessTests.cs @@ -1,9 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. #pragma warning disable IDE0005 // Using directive is unnecessary. 
-using System; using System.Linq; -using System.Runtime.Serialization; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.SemanticKernel; @@ -477,6 +475,7 @@ private void AssertStepStateLastMessage(KernelProcess processInfo, string stepNa } } +#if !NET private void AssertStepState(KernelProcess processInfo, string stepName, Predicate> predicate) where T : class, new() { KernelProcessStepInfo? stepInfo = processInfo.Steps.FirstOrDefault(s => s.State.Name == stepName); @@ -485,5 +484,6 @@ private void AssertStepStateLastMessage(KernelProcess processInfo, string stepNa Assert.NotNull(outputStepResult?.State); Assert.True(predicate(outputStepResult)); } +#endif #endregion } diff --git a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs index 28305a5ec8a4..cf07d36be1f5 100644 --- a/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs +++ b/dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs @@ -13,6 +13,8 @@ using System.Threading.Tasks; using Microsoft.SemanticKernel.Http; +#pragma warning disable CA1859 // Use concrete types when possible for improved performance + namespace Microsoft.SemanticKernel.Plugins.OpenApi; /// diff --git a/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockChatClientTests.cs new file mode 100644 index 000000000000..8d8ebf1eff5f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Amazon/Bedrock/BedrockChatClientTests.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Amazon; + +public class BedrockChatClientTests +{ + [Theory(Skip = "For manual verification only")] + [InlineData("ai21.jamba-instruct-v1:0")] + [InlineData("amazon.titan-text-premier-v1:0")] + [InlineData("amazon.titan-text-lite-v1")] + [InlineData("amazon.titan-text-express-v1")] + [InlineData("anthropic.claude-v2")] + [InlineData("anthropic.claude-v2:1")] + [InlineData("anthropic.claude-instant-v1")] + [InlineData("anthropic.claude-3-sonnet-20240229-v1:0")] + [InlineData("anthropic.claude-3-haiku-20240307-v1:0")] + [InlineData("cohere.command-r-v1:0")] + [InlineData("cohere.command-r-plus-v1:0")] + [InlineData("meta.llama3-70b-instruct-v1:0")] + [InlineData("meta.llama3-8b-instruct-v1:0")] + [InlineData("mistral.mistral-7b-instruct-v0:2")] + [InlineData("mistral.mistral-large-2402-v1:0")] + [InlineData("mistral.mistral-small-2402-v1:0")] + [InlineData("mistral.mixtral-8x7b-instruct-v0:1")] + public async Task ChatCompletionReturnsValidResponseAsync(string modelId) + { + // Arrange + var kernel = Kernel.CreateBuilder().AddBedrockChatClient(modelId).Build(); + + // Act + var message = await kernel.InvokePromptAsync("Hello, I'm Alexa, how are you?").ConfigureAwait(true); + + // Assert + Assert.NotNull(message); + Assert.Equal(ChatRole.Assistant, message.Role); + Assert.NotNull(message.Text); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("ai21.jamba-instruct-v1:0")] + [InlineData("amazon.titan-text-premier-v1:0")] + [InlineData("amazon.titan-text-lite-v1")] + [InlineData("amazon.titan-text-express-v1")] + [InlineData("anthropic.claude-v2")] + [InlineData("anthropic.claude-v2:1")] + [InlineData("anthropic.claude-instant-v1")] + [InlineData("anthropic.claude-3-sonnet-20240229-v1:0")] + [InlineData("anthropic.claude-3-haiku-20240307-v1:0")] + [InlineData("cohere.command-r-v1:0")] 
+ [InlineData("cohere.command-r-plus-v1:0")] + [InlineData("meta.llama3-70b-instruct-v1:0")] + [InlineData("meta.llama3-8b-instruct-v1:0")] + [InlineData("mistral.mistral-7b-instruct-v0:2")] + [InlineData("mistral.mistral-large-2402-v1:0")] + [InlineData("mistral.mistral-small-2402-v1:0")] + [InlineData("mistral.mixtral-8x7b-instruct-v0:1")] + public async Task ChatStreamingReturnsValidResponseAsync(string modelId) + { + // Arrange + var kernel = Kernel.CreateBuilder().AddBedrockChatClient(modelId).Build(); + + // Act + var response = kernel.InvokePromptStreamingAsync("Hello, I'm Alexa, how are you?").ConfigureAwait(true); + string output = ""; + await foreach (var message in response) + { + output += message.Text; + Assert.NotNull(message.RawRepresentation); + } + + // Assert + Assert.NotNull(output); + Assert.False(string.IsNullOrEmpty(output)); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatClientTests.cs new file mode 100644 index 000000000000..d31737f5c697 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/AzureAIInference/AzureAIInferenceChatClientTests.cs @@ -0,0 +1,265 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading.Tasks; +using Azure; +using Azure.AI.Inference; +using Azure.Identity; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Http.Resilience; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureAIInference; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureAIInference; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class AzureAIInferenceChatClientTests(ITestOutputHelper output) : BaseIntegrationTest, IDisposable +{ + private const string SkipReason = "For manual verification only"; + private const string InputParameterName = "input"; + private readonly XunitLogger _loggerFactory = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = SkipReason)] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task InvokeGetResponseAsync(string prompt, string expectedAnswerContains) + { + // Arrange + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + + IChatClient sut = this.CreateChatClient(config); + + List chatHistory = [new(Microsoft.Extensions.AI.ChatRole.User, prompt)]; + + // Act + var result = await sut.GetResponseAsync(chatHistory); + + // Assert + 
Assert.NotNull(result); + Assert.Contains(expectedAnswerContains, result.Text, StringComparison.OrdinalIgnoreCase); + } + + [Theory(Skip = SkipReason)] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task InvokeGetStreamingResponseAsync(string prompt, string expectedAnswerContains) + { + // Arrange + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + + IChatClient sut = this.CreateChatClient(config); + + List chatHistory = [new(Microsoft.Extensions.AI.ChatRole.User, prompt)]; + + StringBuilder fullContent = new(); + + // Act + await foreach (var update in sut.GetStreamingResponseAsync(chatHistory)) + { + fullContent.Append(update.Text); + } + + // Assert + Assert.Contains(expectedAnswerContains, fullContent.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = SkipReason)] + public async Task ItCanUseChatForTextGenerationAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using bullet points.", + new AzureAIInferencePromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact(Skip = SkipReason)] + public async Task ItStreamingFromKernelTestAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResult = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], 
new() { [InputParameterName] = prompt })) + { + fullResult.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = SkipReason)] + public async Task ItHttpRetryPolicyTestAsync() + { + // Arrange + List statusCodes = []; + + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ChatModelId); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddAzureAIInferenceChatClient(modelId: config.ChatModelId, endpoint: config.Endpoint, apiKey: "wrong"); + + kernelBuilder.Services.ConfigureHttpClientDefaults(c => + { + // Use a standard resiliency policy, augmented to retry on 401 Unauthorized for this example + c.AddStandardResilienceHandler().Configure(o => + { + o.Retry.ShouldHandle = args => ValueTask.FromResult(args.Outcome.Result?.StatusCode is HttpStatusCode.Unauthorized); + o.Retry.OnRetry = args => + { + statusCodes.Add(args.Outcome.Result?.StatusCode); + return ValueTask.CompletedTask; + }; + }); + }); + + var target = kernelBuilder.Build(); + + var plugins = TestHelpers.ImportSamplePlugins(target, "SummarizePlugin"); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); + + // Assert + Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); + Assert.Equal((int)HttpStatusCode.Unauthorized, ((RequestFailedException)exception).Status); + } + + [Fact(Skip = SkipReason)] + public async Task ItShouldReturnInnerContentAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + var result = await kernel.InvokeAsync(plugins["FunPlugin"]["Limerick"]); + var 
content = result.GetValue(); + // Assert + Assert.NotNull(content); + Assert.NotNull(content.InnerContent); + + Assert.IsType(content.InnerContent); + var completions = (ChatCompletions)content.InnerContent; + var usage = completions.Usage; + + // Usage + Assert.NotEqual(0, usage.PromptTokens); + Assert.NotEqual(0, usage.CompletionTokens); + } + + [Theory(Skip = SkipReason)] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. " + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + // Act + FunctionResult actual = await kernel.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains("John", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + private Kernel CreateAndInitializeKernel(HttpClient? 
httpClient = null) + { + var config = this._configuration.GetSection("AzureAIInference").Get(); + Assert.NotNull(config); + Assert.NotNull(config.ApiKey); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ChatModelId); + + var kernelBuilder = base.CreateKernelBuilder(); + + kernelBuilder.AddAzureAIInferenceChatClient( + config.ChatModelId, + endpoint: config.Endpoint, + apiKey: config.ApiKey, + serviceId: config.ServiceId, + httpClient: httpClient); + + return kernelBuilder.Build(); + } + + private IChatClient CreateChatClient(AzureAIInferenceConfiguration config) + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(this._loggerFactory); + + Assert.NotNull(config.ChatModelId); + + if (config.ApiKey is not null) + { + serviceCollection.AddAzureAIInferenceChatClient( + modelId: config.ChatModelId, + endpoint: config.Endpoint, + apiKey: config.ApiKey); + } + else + { + serviceCollection.AddAzureAIInferenceChatClient( + modelId: config.ChatModelId, + endpoint: config.Endpoint, + credential: new AzureCliCredential()); + } + + var serviceProvider = serviceCollection.BuildServiceProvider(); + + return serviceProvider.GetRequiredService(); + } + + public void Dispose() + { + this._loggerFactory.Dispose(); + this._testOutputHelper.Dispose(); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClientTests.cs new file mode 100644 index 000000000000..527f6d10fa2b --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAIChatClientTests.cs @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Azure.Identity; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Http.Resilience; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.AzureOpenAI; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using OpenAI.Chat; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI; + +public sealed class AzureOpenAIChatClientTests : BaseIntegrationTest +{ + [Fact] + public async Task ItCanUseAzureOpenAiChatForTextGenerationAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using bullet points.", + new OpenAIPromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task AzureOpenAIStreamingTestAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResult = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResult.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } 
+ + [Fact] + public async Task AzureOpenAIHttpRetryPolicyTestAsync() + { + // Arrange + List statusCodes = []; + + var config = this._configuration.GetSection("AzureOpenAI").Get(); + Assert.NotNull(config); + Assert.NotNull(config.DeploymentName); + Assert.NotNull(config.Endpoint); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddAzureOpenAIChatCompletion( + deploymentName: config.DeploymentName, + endpoint: config.Endpoint, + apiKey: "INVALID_KEY"); + + kernelBuilder.Services.ConfigureHttpClientDefaults(c => + { + // Use a standard resiliency policy, augmented to retry on 401 Unauthorized for this example + c.AddStandardResilienceHandler().Configure(o => + { + o.Retry.ShouldHandle = args => ValueTask.FromResult(args.Outcome.Result?.StatusCode is HttpStatusCode.Unauthorized); + o.Retry.OnRetry = args => + { + statusCodes.Add(args.Outcome.Result?.StatusCode); + return ValueTask.CompletedTask; + }; + }); + }); + + var target = kernelBuilder.Build(); + + var plugins = TestHelpers.ImportSamplePlugins(target, "SummarizePlugin"); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); + + // Assert + Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); + Assert.Equal(HttpStatusCode.Unauthorized, exception.StatusCode); + } + + [Fact] + public async Task AzureOpenAIShouldReturnUsageAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + var result = await kernel.InvokeAsync(plugins["FunPlugin"]["Limerick"]); + + // Assert + var chatResponse = result.GetValue(); + + Assert.NotNull(chatResponse); + Assert.NotNull(chatResponse.Usage); + Assert.NotEqual(0, chatResponse.Usage.InputTokenCount); + Assert.NotEqual(0, 
chatResponse.Usage.OutputTokenCount); + } + + [Theory(Skip = "This test is for manual verification.")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. " + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + // Act + FunctionResult actual = await kernel.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains("John", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "Currently not supported - Chat System Prompt is not surfacing as a system message level")] + public async Task ChatSystemPromptIsNotIgnoredAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var settings = new OpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." 
}; + + // Act + var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?", new(settings)); + + // Assert + Assert.Contains("I don't know", result.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task SemanticKernelVersionHeaderIsSentAsync() + { + // Arrange + using var defaultHandler = new HttpClientHandler(); + using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler); + using var httpClient = new HttpClient(httpHeaderHandler); + + var kernel = this.CreateAndInitializeKernel(httpClient); + + // Act + await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + + // Assert + Assert.NotNull(httpHeaderHandler.RequestHeaders); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var _)); + } + + //[Theory(Skip = "This test is for manual verification.")] + [Theory(Skip = "Currently not supported - Log Probabilities is not surfacing to the API level")] + [InlineData(null, null)] + [InlineData(false, null)] + [InlineData(true, 2)] + [InlineData(true, 5)] + public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs) + { + // Arrange + var settings = new AzureOpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs }; + + var kernel = this.CreateAndInitializeKernel(); + + // Act + var result = await kernel.InvokePromptAsync("Hi, can you help me today?", new(settings)); + + var chatResponse = result.GetValue(); + var logProbabilityInfo = result.Metadata!["ContentTokenLogProbabilities"] as IReadOnlyList; + + // Assert + Assert.NotNull(logProbabilityInfo); + + if (logprobs is true) + { + Assert.NotNull(logProbabilityInfo); + Assert.Equal(topLogprobs, logProbabilityInfo[0].TopLogProbabilities.Count); + } + else + { + Assert.Empty(logProbabilityInfo); + } + } + private Kernel CreateAndInitializeKernel(HttpClient? 
httpClient = null) + { + var config = this._configuration.GetSection("AzureOpenAI").Get(); + Assert.NotNull(config); + Assert.NotNull(config.ChatDeploymentName); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ServiceId); + + var kernelBuilder = this.CreateKernelBuilder(); + + kernelBuilder.AddAzureOpenAIChatClient( + deploymentName: config.ChatDeploymentName, + modelId: config.ChatModelId, + endpoint: config.Endpoint, + credentials: new AzureCliCredential(), + serviceId: config.ServiceId, + httpClient: httpClient); + + return kernelBuilder.Build(); + } + + private const string InputParameterName = "input"; + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler) + { + public System.Net.Http.Headers.HttpRequestHeaders? 
RequestHeaders { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + this.RequestHeaders = request.Headers; + return await base.SendAsync(request, cancellationToken); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAITextEmbeddingTests.cs b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAITextEmbeddingTests.cs index 401b29a2cba0..8b7bf3bf4b61 100644 --- a/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAITextEmbeddingTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/AzureOpenAI/AzureOpenAITextEmbeddingTests.cs @@ -3,7 +3,9 @@ using System; using System.Threading.Tasks; using Azure.Identity; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.AzureOpenAI; using Microsoft.SemanticKernel.Embeddings; using SemanticKernel.IntegrationTests.TestSettings; @@ -72,3 +74,70 @@ public async Task AzureOpenAITextEmbeddingGenerationWithDimensionsAsync(int? 
dim .AddUserSecrets() .Build(); } + +public sealed class AzureOpenAIEmbeddingGeneratorTests +{ + public AzureOpenAIEmbeddingGeneratorTests() + { + var config = this._configuration.GetSection("AzureOpenAIEmbeddings").Get(); + Assert.NotNull(config); + this._azureOpenAIConfiguration = config; + } + + [Theory] + [InlineData("test sentence")] + public async Task AzureOpenAITestAsync(string testInputString) + { + // Arrange + var embeddingGenerator = Kernel.CreateBuilder() + .AddAzureOpenAIEmbeddingGenerator( + deploymentName: this._azureOpenAIConfiguration.DeploymentName, + endpoint: this._azureOpenAIConfiguration.Endpoint, + credential: new AzureCliCredential()) + .Build() + .GetRequiredService>>(); + + // Act + var singleResult = await embeddingGenerator.GenerateAsync(testInputString); + var batchResult = await embeddingGenerator.GenerateAsync([testInputString]); + + // Assert + Assert.Equal(AdaVectorLength, singleResult.Vector.Length); + Assert.Single(batchResult); + } + + [Theory] + [InlineData(null, 3072)] + [InlineData(1024, 1024)] + public async Task AzureOpenAITextEmbeddingGenerationWithDimensionsAsync(int? 
dimensions, int expectedVectorLength) + { + // Arrange + const string TestInputString = "test sentence"; + + var embeddingGenerator = Kernel.CreateBuilder() + .AddAzureOpenAIEmbeddingGenerator( + deploymentName: "text-embedding-3-large", + endpoint: this._azureOpenAIConfiguration.Endpoint, + credential: new AzureCliCredential(), + dimensions: dimensions) + .Build() + .GetRequiredService>>(); + + // Act + var result = await embeddingGenerator.GenerateAsync(TestInputString); + + // Assert + Assert.Equal(expectedVectorLength, result.Vector.Length); + } + + private readonly AzureOpenAIConfiguration _azureOpenAIConfiguration; + + private const int AdaVectorLength = 1536; + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); +} diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs index 90375307c533..9e1127fa8b55 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs @@ -22,7 +22,7 @@ public sealed class OpenAIAudioToTextTests() .AddUserSecrets() .Build(); - [RetryFact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [RetryFact] //(Skip = "OpenAI will often throttle requests. 
This test is for manual verification.")] public async Task OpenAIAudioToTextTestAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs new file mode 100644 index 000000000000..3351f2a78996 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatClientTests.cs @@ -0,0 +1,254 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Http.Resilience; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using OpenAI.Chat; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.OpenAI; + +public sealed class OpenAIChatClientTests : BaseIntegrationTest +{ + [Fact] + public async Task ItCanUseOpenAiChatForTextGenerationAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using bullet points.", + new OpenAIPromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task OpenAIStreamingTestAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResult = new(); + + 
var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResult.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task OpenAIHttpRetryPolicyTestAsync() + { + // Arrange + List statusCodes = []; + + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + + var kernelBuilder = Kernel.CreateBuilder(); + + kernelBuilder.AddOpenAIChatCompletion( + modelId: openAIConfiguration.ChatModelId, + apiKey: "INVALID_KEY"); + + kernelBuilder.Services.ConfigureHttpClientDefaults(c => + { + // Use a standard resiliency policy, augmented to retry on 401 Unauthorized for this example + c.AddStandardResilienceHandler().Configure(o => + { + o.Retry.ShouldHandle = args => ValueTask.FromResult(args.Outcome.Result?.StatusCode is HttpStatusCode.Unauthorized); + o.Retry.OnRetry = args => + { + statusCodes.Add(args.Outcome.Result?.StatusCode); + return ValueTask.CompletedTask; + }; + }); + }); + + var target = kernelBuilder.Build(); + + var plugins = TestHelpers.ImportSamplePlugins(target, "SummarizePlugin"); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + var exception = await Assert.ThrowsAsync(() => target.InvokeAsync(plugins["SummarizePlugin"]["Summarize"], new() { [InputParameterName] = prompt })); + + // Assert + Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); + Assert.Equal(HttpStatusCode.Unauthorized, exception.StatusCode); + } + + [Fact] + public async Task OpenAIShouldReturnUsageAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, 
"FunPlugin"); + + // Act + var result = await kernel.InvokeAsync(plugins["FunPlugin"]["Limerick"]); + + // Assert + var chatResponse = result.GetValue(); + + Assert.NotNull(chatResponse); + Assert.NotNull(chatResponse.Usage); + Assert.NotEqual(0, chatResponse.Usage.InputTokenCount); + Assert.NotEqual(0, chatResponse.Usage.OutputTokenCount); + } + + [Theory(Skip = "This test is for manual verification.")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task CompletionWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. " + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + var kernel = this.CreateAndInitializeKernel(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + // Act + FunctionResult actual = await kernel.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains("John", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "Currently not supported - Chat System Prompt is not surfacing as a system message level")] + public async Task ChatSystemPromptIsNotIgnoredAsync() + { + // Arrange + var kernel = this.CreateAndInitializeKernel(); + + var settings = new OpenAIPromptExecutionSettings { ChatSystemPrompt = "Reply \"I don't know\" to every question." 
}; + + // Act + var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?", new(settings)); + + // Assert + Assert.Contains("I don't know", result.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task SemanticKernelVersionHeaderIsSentAsync() + { + // Arrange + using var defaultHandler = new HttpClientHandler(); + using var httpHeaderHandler = new HttpHeaderHandler(defaultHandler); + using var httpClient = new HttpClient(httpHeaderHandler); + + var kernel = this.CreateAndInitializeKernel(httpClient); + + // Act + await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + + // Assert + Assert.NotNull(httpHeaderHandler.RequestHeaders); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var _)); + } + + //[Theory(Skip = "This test is for manual verification.")] + [Theory(Skip = "Currently not supported - Log Probabilities is not surfacing to the API level")] + [InlineData(null, null)] + [InlineData(false, null)] + [InlineData(true, 2)] + [InlineData(true, 5)] + public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs) + { + // Arrange + var settings = new OpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs }; + + var kernel = this.CreateAndInitializeKernel(); + + // Act + var result = await kernel.InvokePromptAsync("Hi, can you help me today?", new(settings)); + + var chatResponse = result.GetValue(); + var logProbabilityInfo = result.Metadata!["ContentTokenLogProbabilities"] as IReadOnlyList; + + // Assert + Assert.NotNull(logProbabilityInfo); + + if (logprobs is true) + { + Assert.NotNull(logProbabilityInfo); + Assert.Equal(topLogprobs, logProbabilityInfo[0].TopLogProbabilities.Count); + } + else + { + Assert.Empty(logProbabilityInfo); + } + } + private Kernel CreateAndInitializeKernel(HttpClient? 
httpClient = null) + { + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); + + var kernelBuilder = this.CreateKernelBuilder(); + + kernelBuilder.AddOpenAIChatClient( + modelId: openAIConfiguration.ChatModelId, + apiKey: openAIConfiguration.ApiKey, + serviceId: openAIConfiguration.ServiceId, + httpClient: httpClient); + + return kernelBuilder.Build(); + } + + private const string InputParameterName = "input"; + + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + private sealed class HttpHeaderHandler(HttpMessageHandler innerHandler) : DelegatingHandler(innerHandler) + { + public System.Net.Http.Headers.HttpRequestHeaders? 
RequestHeaders { get; private set; } + + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + this.RequestHeaders = request.Headers; + return await base.SendAsync(request, cancellationToken); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs index 1359b701e29c..ddfe6b997a25 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs @@ -8,12 +8,13 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Http.Resilience; -using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; +using OpenAI; using OpenAI.Chat; using SemanticKernel.IntegrationTests.TestSettings; using Xunit; @@ -43,6 +44,40 @@ public async Task ItCanUseOpenAiChatForTextGenerationAsync() Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); } + [Fact] + public async Task ItCanUseOpenAiChatClientAndContentsAsync() + { + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); + + // Arrange + var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey); + var builder = Kernel.CreateBuilder(); + builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient()); + var kernel = builder.Build(); + + var func = kernel.CreateFunctionFromPrompt( + "List the two planets after '{{$input}}', excluding moons, using 
bullet points.", + new OpenAIPromptExecutionSettings()); + + // Act + var result = await func.InvokeAsync(kernel, new() { [InputParameterName] = "Jupiter" }); + + // Assert + Assert.NotNull(result); + Assert.Contains("Saturn", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + Assert.Contains("Uranus", result.GetValue(), StringComparison.InvariantCultureIgnoreCase); + var chatResponse = Assert.IsType(result.GetValue()); + Assert.Contains("Saturn", chatResponse.Text, StringComparison.InvariantCultureIgnoreCase); + var chatMessage = Assert.IsType(result.GetValue()); + Assert.Contains("Uranus", chatMessage.Text, StringComparison.InvariantCultureIgnoreCase); + var chatMessageContent = Assert.IsType(result.GetValue()); + Assert.Contains("Uranus", chatMessageContent.Content, StringComparison.InvariantCultureIgnoreCase); + } + [Fact] public async Task OpenAIStreamingTestAsync() { @@ -65,6 +100,43 @@ public async Task OpenAIStreamingTestAsync() Assert.Contains("Pike Place", fullResult.ToString(), StringComparison.OrdinalIgnoreCase); } + [Fact] + public async Task ItCanUseOpenAiStreamingChatClientAndContentsAsync() + { + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); + + // Arrange + var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey); + var builder = Kernel.CreateBuilder(); + builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient()); + var kernel = builder.Build(); + + var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); + + StringBuilder fullResultSK = new(); + StringBuilder fullResultMEAI = new(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + await foreach (var content in 
kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResultSK.Append(content); + } + await foreach (var content in kernel.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResultMEAI.Append(content); + } + + // Assert + Assert.Contains("Pike Place", fullResultSK.ToString(), StringComparison.OrdinalIgnoreCase); + Assert.Contains("Pike Place", fullResultMEAI.ToString(), StringComparison.OrdinalIgnoreCase); + } + [Fact] public async Task OpenAIHttpRetryPolicyTestAsync() { @@ -106,7 +178,7 @@ public async Task OpenAIHttpRetryPolicyTestAsync() // Assert Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); - Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode); + Assert.Equal(HttpStatusCode.Unauthorized, exception.StatusCode); } [Fact] @@ -185,11 +257,11 @@ public async Task SemanticKernelVersionHeaderIsSentAsync() var kernel = this.CreateAndInitializeKernel(httpClient); // Act - var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); // Assert Assert.NotNull(httpHeaderHandler.RequestHeaders); - Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var _)); } //[Theory(Skip = "This test is for manual verification.")] @@ -228,18 +300,18 @@ public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? private Kernel CreateAndInitializeKernel(HttpClient? 
httpClient = null) { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); - var kernelBuilder = base.CreateKernelBuilder(); + var kernelBuilder = this.CreateKernelBuilder(); kernelBuilder.AddOpenAIChatCompletion( - modelId: OpenAIConfiguration.ChatModelId, - apiKey: OpenAIConfiguration.ApiKey, - serviceId: OpenAIConfiguration.ServiceId, + modelId: openAIConfiguration.ChatModelId, + apiKey: openAIConfiguration.ApiKey, + serviceId: openAIConfiguration.ServiceId, httpClient: httpClient); return kernelBuilder.Build(); diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIEmbeddingGeneratorTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIEmbeddingGeneratorTests.cs new file mode 100644 index 000000000000..599ee15b5d17 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIEmbeddingGeneratorTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.OpenAI; + +public sealed class OpenAIEmbeddingGeneratorTests +{ + private const int AdaVectorLength = 1536; + private const string AdaModelId = "text-embedding-ada-002"; + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [InlineData("test sentence")] + public async Task OpenAITestAsync(string testInputString) + { + // Arrange + OpenAIConfiguration? openAIConfiguration = this._configuration.GetSection("OpenAIEmbeddings").Get(); + Assert.NotNull(openAIConfiguration); + + var embeddingGenerator = Kernel.CreateBuilder() + .AddOpenAIEmbeddingGenerator(AdaModelId, openAIConfiguration.ApiKey) + .Build() + .GetRequiredService>>(); + + // Act + var singleResult = await embeddingGenerator.GenerateAsync(testInputString); + var batchResult = await embeddingGenerator.GenerateAsync([testInputString, testInputString, testInputString]); + + // Assert + Assert.Equal(AdaVectorLength, singleResult.Vector.Length); + Assert.Equal(3, batchResult.Count); + } + + [Theory(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [InlineData(null, 3072)] + [InlineData(1024, 1024)] + public async Task OpenAIWithDimensionsAsync(int? dimensions, int expectedVectorLength) + { + // Arrange + const string TestInputString = "test sentence"; + + OpenAIConfiguration? 
openAIConfiguration = this._configuration.GetSection("OpenAIEmbeddings").Get(); + Assert.NotNull(openAIConfiguration); + + var embeddingGenerator = Kernel.CreateBuilder() + .AddOpenAIEmbeddingGenerator("text-embedding-3-large", openAIConfiguration.ApiKey, dimensions: dimensions) + .Build() + .GetRequiredService>>(); + + // Act + var result = await embeddingGenerator.GenerateAsync(TestInputString); + + // Assert + Assert.Equal(expectedVectorLength, result.Vector.Length); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs index c2818abe2502..420295fe4349 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs @@ -18,7 +18,7 @@ public sealed class OpenAITextToAudioTests .AddUserSecrets() .Build(); - [Fact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [Fact] //(Skip = "OpenAI will often throttle requests. 
This test is for manual verification.")] public async Task OpenAITextToAudioTestAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/Planners/Handlebars/HandlebarsPlannerTests.cs b/dotnet/src/IntegrationTests/Planners/Handlebars/HandlebarsPlannerTests.cs index 1f344160f026..91496a8311fd 100644 --- a/dotnet/src/IntegrationTests/Planners/Handlebars/HandlebarsPlannerTests.cs +++ b/dotnet/src/IntegrationTests/Planners/Handlebars/HandlebarsPlannerTests.cs @@ -127,7 +127,7 @@ private Kernel InitializeKernel(bool useEmbeddings = false) deploymentName: azureOpenAIEmbeddingsConfiguration.DeploymentName, modelId: azureOpenAIEmbeddingsConfiguration.EmbeddingModelId, endpoint: azureOpenAIEmbeddingsConfiguration.Endpoint, - credential: new AzureCliCredential()); + credentials: new AzureCliCredential()); } return builder.Build(); diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index 97cb426c307d..55a300769812 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -51,7 +51,7 @@ internal sealed class FunctionCallsProcessor /// will be disabled. This is a safeguard against possible runaway execution if the model routinely re-requests /// the same function over and over. /// - private const int MaximumAutoInvokeAttempts = 128; + internal const int MaximumAutoInvokeAttempts = 128; /// Tracking for . /// @@ -211,7 +211,7 @@ public FunctionCallsProcessor(ILogger? logger = null) { bool terminationRequested = false; - // Wait for all of the function invocations to complete, then add the results to the chat, but stop when we hit a + // Wait for all the function invocations to complete, then add the results to the chat, but stop when we hit a // function for which termination was requested. 
FunctionResultContext[] resultContexts = await Task.WhenAll(functionTasks).ConfigureAwait(false); foreach (FunctionResultContext resultContext in resultContexts) @@ -487,8 +487,8 @@ public static string ProcessFunctionResult(object functionResult) return stringResult; } - // This is an optimization to use ChatMessageContent content directly - // without unnecessary serialization of the whole message content class. + // This is an optimization to use ChatMessageContent content directly + // without unnecessary serialization of the whole message content class. if (functionResult is ChatMessageContent chatMessageContent) { return chatMessageContent.ToString(); diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs index 2219e6151b9d..9218bdc08bce 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs @@ -1,9 +1,13 @@ // Copyright (c) Microsoft. All rights reserved. +using System.ClientModel; using System.Reflection; using System.Text; using System.Text.Json; +using Azure.AI.OpenAI; using Azure.Identity; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; @@ -46,10 +50,22 @@ public abstract class BaseTest : TextWriter protected bool UseBingSearch => TestConfiguration.Bing.ApiKey is not null; protected Kernel CreateKernelWithChatCompletion(string? modelName = null) + => this.CreateKernelWithChatCompletion(useChatClient: false, out _); + + protected Kernel CreateKernelWithChatCompletion(bool useChatClient, out IChatClient? chatClient, string? 
modelName = null) { var builder = Kernel.CreateBuilder(); AddChatCompletionToKernel(builder, modelName); + if (useChatClient) + { + chatClient = AddChatClientToKernel(builder); + } + else + { + chatClient = null; + AddChatCompletionToKernel(builder); + } return builder.Build(); } @@ -78,6 +94,39 @@ protected void AddChatCompletionToKernel(IKernelBuilder builder, string? modelNa } } + protected IChatClient AddChatClientToKernel(IKernelBuilder builder) + { +#pragma warning disable CA2000 // Dispose objects before losing scope + IChatClient chatClient; + if (this.UseOpenAIConfig) + { + chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); + } + else if (!string.IsNullOrEmpty(this.ApiKey)) + { + chatClient = new AzureOpenAIClient( + endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), + credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient(); + } + else + { + chatClient = new AzureOpenAIClient( + endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), + credential: new AzureCliCredential()) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient(); + } + + var functionCallingChatClient = chatClient!.AsBuilder().UseKernelFunctionInvocation().Build(); + builder.Services.AddTransient((sp) => functionCallingChatClient); + return functionCallingChatClient; +#pragma warning restore CA2000 // Dispose objects before losing scope + } + protected BaseTest(ITestOutputHelper output, bool redirectSystemConsoleOutput = false) { this.Output = output; diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs b/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs index 1cb1c96ae181..a359389756a7 100644 --- a/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs +++ b/dotnet/src/InternalUtilities/src/Diagnostics/KernelVerify.cs @@ 
-46,7 +46,7 @@ internal static void ValidFunctionName([NotNull] string? functionName, [CallerAr /// Make sure every function parameter name is unique /// /// List of parameters - internal static void ParametersUniqueness(IReadOnlyList parameters) + internal static IReadOnlyList ParametersUniqueness(IReadOnlyList parameters) { int count = parameters.Count; if (count > 0) @@ -74,5 +74,7 @@ internal static void ParametersUniqueness(IReadOnlyList } } } + + return parameters; } } diff --git a/dotnet/src/InternalUtilities/src/Diagnostics/Throw.cs b/dotnet/src/InternalUtilities/src/Diagnostics/Throw.cs new file mode 100644 index 000000000000..77a1115561ef --- /dev/null +++ b/dotnet/src/InternalUtilities/src/Diagnostics/Throw.cs @@ -0,0 +1,984 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +// Source Originally from: https://github.com/dotnet/extensions/blob/ef3f0a/src/Shared/Throw/Throw.cs + +namespace Microsoft.SemanticKernel; + +/// +/// Defines static methods used to throw exceptions. +/// +/// +/// The main purpose is to reduce code size, improve performance, and standardize exception +/// messages. +/// +[SuppressMessage("Minor Code Smell", "S4136:Method overloads should be grouped together", Justification = "Doesn't work with the region layout")] +[SuppressMessage("Minor Code Smell", "S2333:Partial is gratuitous in this context", Justification = "Some projects add additional partial parts.")] +[SuppressMessage("Design", "CA1716", Justification = "Not part of an API")] + +[ExcludeFromCodeCoverage] +internal static partial class Throw +{ + #region For Object + + /// + /// Throws an if the specified argument is . + /// + /// Argument type to be checked for . + /// Object to be checked for . + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument is null) + { + ArgumentNullException(paramName); + } + + return argument; + } + + /// + /// Throws an if the specified argument is , + /// or if the specified member is . + /// + /// Argument type to be checked for . + /// Member type to be checked for . + /// Argument to be checked for . + /// Object member to be checked for . + /// The name of the parameter being checked. + /// The name of the member. + /// The original value of . + /// + /// + /// Throws.IfNullOrMemberNull(myObject, myObject?.MyProperty) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static TMember IfNullOrMemberNull( + [NotNull] TParameter argument, + [NotNull] TMember member, + [CallerArgumentExpression(nameof(argument))] string paramName = "", + [CallerArgumentExpression(nameof(member))] string memberName = "") + { + if (argument is null) + { + ArgumentNullException(paramName); + } + + if (member is null) + { + ArgumentException(paramName, $"Member {memberName} of {paramName} is null"); + } + + return member; + } + + /// + /// Throws an if the specified member is . + /// + /// Argument type. + /// Member type to be checked for . + /// Argument to which member belongs. + /// Object member to be checked for . + /// The name of the parameter being checked. + /// The name of the member. + /// The original value of . 
+ /// + /// + /// Throws.IfMemberNull(myObject, myObject.MyProperty) + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + [SuppressMessage("Style", "IDE0060:Remove unused parameter", Justification = "Analyzer isn't seeing the reference to 'argument' in the attribute")] + public static TMember IfMemberNull( + TParameter argument, + [NotNull] TMember member, + [CallerArgumentExpression(nameof(argument))] string paramName = "", + [CallerArgumentExpression(nameof(member))] string memberName = "") + where TParameter : notnull + { + if (member is null) + { + ArgumentException(paramName, $"Member {memberName} of {paramName} is null"); + } + + return member; + } + + #endregion + + #region For String + + /// + /// Throws either an or an + /// if the specified string is or whitespace respectively. + /// + /// String to be checked for or whitespace. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static string IfNullOrWhitespace([NotNull] string? argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { +#if !NETCOREAPP3_1_OR_GREATER + if (argument == null) + { + ArgumentNullException(paramName); + } +#endif + + if (string.IsNullOrWhiteSpace(argument)) + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + ArgumentException(paramName, "Argument is whitespace"); + } + } + + return argument; + } + + /// + /// Throws an if the string is , + /// or if it is empty. + /// + /// String to be checked for or empty. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static string IfNullOrEmpty([NotNull] string? 
argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { +#if !NETCOREAPP3_1_OR_GREATER + if (argument == null) + { + ArgumentNullException(paramName); + } +#endif + + if (string.IsNullOrEmpty(argument)) + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + ArgumentException(paramName, "Argument is an empty string"); + } + } + + return argument; + } + + #endregion + + #region For Buffer + + /// + /// Throws an if the argument's buffer size is less than the required buffer size. + /// + /// The actual buffer size. + /// The required buffer size. + /// The name of the parameter to be checked. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void IfBufferTooSmall(int bufferSize, int requiredSize, string paramName = "") + { + if (bufferSize < requiredSize) + { + ArgumentException(paramName, $"Buffer too small, needed a size of {requiredSize} but got {bufferSize}"); + } + } + + #endregion + + #region For Enums + + /// + /// Throws an if the enum value is not valid. + /// + /// The argument to evaluate. + /// The name of the parameter being checked. + /// The type of the enumeration. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static T IfOutOfRange(T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + where T : struct, Enum + { +#if NET5_0_OR_GREATER + if (!Enum.IsDefined(argument)) +#else + if (!Enum.IsDefined(typeof(T), argument)) +#endif + { + ArgumentOutOfRangeException(paramName, $"{argument} is an invalid value for enum type {typeof(T)}"); + } + + return argument; + } + + #endregion + + #region For Collections + + /// + /// Throws an if the collection is , + /// or if it is empty. + /// + /// The collection to evaluate. + /// The name of the parameter being checked. + /// The type of objects in the collection. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + + // The method has actually 100% coverage, but due to a bug in the code coverage tool, + // a lower number is reported. Therefore, we temporarily exclude this method + // from the coverage measurements. Once the bug in the code coverage tool is fixed, + // the exclusion attribute can be removed. + [ExcludeFromCodeCoverage] + public static IEnumerable IfNullOrEmpty([NotNull] IEnumerable? argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + switch (argument) + { + case ICollection collection: + if (collection.Count == 0) + { + ArgumentException(paramName, "Collection is empty"); + } + + break; + case IReadOnlyCollection readOnlyCollection: + if (readOnlyCollection.Count == 0) + { + ArgumentException(paramName, "Collection is empty"); + } + + break; + default: + using (IEnumerator enumerator = argument.GetEnumerator()) + { + if (!enumerator.MoveNext()) + { + ArgumentException(paramName, "Collection is empty"); + } + } + + break; + } + } + + return argument; + } + + #endregion + + #region Exceptions + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentNullException(string paramName) + => throw new ArgumentNullException(paramName); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentNullException(string paramName, string? message) + => throw new ArgumentNullException(paramName, message); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. 
+#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName) + => throw new ArgumentOutOfRangeException(paramName); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName, string? message) + => throw new ArgumentOutOfRangeException(paramName, message); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// The value of the argument that caused this exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName, object? actualValue, string? message) + => throw new ArgumentOutOfRangeException(paramName, actualValue, message); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentException(string paramName, string? message) + => throw new ArgumentException(message, paramName); + + /// + /// Throws an . + /// + /// The name of the parameter that caused the exception. + /// A message that describes the error. + /// The exception that is the cause of the current exception. + /// + /// If the is not a , the current exception is raised in a catch + /// block that handles the inner exception. + /// +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void ArgumentException(string paramName, string? message, Exception? 
innerException) + => throw new ArgumentException(message, paramName, innerException); + + /// + /// Throws an . + /// + /// A message that describes the error. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void InvalidOperationException(string message) + => throw new InvalidOperationException(message); + + /// + /// Throws an . + /// + /// A message that describes the error. + /// The exception that is the cause of the current exception. +#if !NET6_0_OR_GREATER + [MethodImpl(MethodImplOptions.NoInlining)] +#endif + [DoesNotReturn] + public static void InvalidOperationException(string message, Exception? innerException) + => throw new InvalidOperationException(message, innerException); + + #endregion + + #region For Integer + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfLessThan(int argument, int min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfGreaterThan(int argument, int max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfLessThanOrEqual(int argument, int min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfGreaterThanOrEqual(int argument, int max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfOutOfRange(int argument, int min, int max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfZero(int argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Unsigned Integer + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfLessThan(uint argument, uint min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfGreaterThan(uint argument, uint max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfLessThanOrEqual(uint argument, uint min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfGreaterThanOrEqual(uint argument, uint max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfOutOfRange(uint argument, uint min, uint max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static uint IfZero(uint argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0U) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Long + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfLessThan(long argument, long min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfGreaterThan(long argument, long max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfLessThanOrEqual(long argument, long min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfGreaterThanOrEqual(long argument, long max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfOutOfRange(long argument, long min, long max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static long IfZero(long argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0L) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Unsigned Long + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfLessThan(ulong argument, ulong min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfGreaterThan(ulong argument, ulong max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfLessThanOrEqual(ulong argument, ulong min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument <= min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfGreaterThanOrEqual(ulong argument, ulong max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument >= max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfOutOfRange(ulong argument, ulong min, ulong max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min || argument > max) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ulong IfZero(ulong argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument == 0UL) + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion + + #region For Double + + /// + /// Throws an if the specified number is less than min. + /// + /// Number to be expected being less than min. + /// The number that must be less than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfLessThan(double argument, double min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument >= min)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater than max. + /// + /// Number to be expected being greater than max. + /// The number that must be greater than the argument. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfGreaterThan(double argument, double max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument <= max)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is less or equal than min. + /// + /// Number to be expected being less or equal than min. + /// The number that must be less or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfLessThanOrEqual(double argument, double min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument > min)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less or equal than minimum value {min}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is greater or equal than max. + /// + /// Number to be expected being greater or equal than max. + /// The number that must be greater or equal than the argument. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfGreaterThanOrEqual(double argument, double max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly +#pragma warning disable S1940 // Boolean checks should not be inverted + if (!(argument < max)) +#pragma warning restore S1940 // Boolean checks should not be inverted + { + ArgumentOutOfRangeException(paramName, argument, $"Argument greater or equal than maximum value {max}"); + } + + return argument; + } + + /// + /// Throws an if the specified number is not in the specified range. + /// + /// Number to be expected being greater or equal than max. + /// The lower bound of the allowed range of argument values. + /// The upper bound of the allowed range of argument values. + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfOutOfRange(double argument, double min, double max, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + // strange conditional needed in order to handle NaN values correctly + if (!(min <= argument && argument <= max)) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument not in the range [{min}..{max}]"); + } + + return argument; + } + + /// + /// Throws an if the specified number is equal to 0. + /// + /// Number to be expected being not equal to zero. + /// The name of the parameter being checked. + /// The original value of . 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static double IfZero(double argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { +#pragma warning disable S1244 // Floating point numbers should not be tested for equality + if (argument == 0.0) +#pragma warning restore S1244 // Floating point numbers should not be tested for equality + { + ArgumentOutOfRangeException(paramName, "Argument is zero"); + } + + return argument; + } + + #endregion +} diff --git a/dotnet/src/InternalUtilities/src/EmptyCollections/EmptyReadonlyDictionary.cs b/dotnet/src/InternalUtilities/src/EmptyCollections/EmptyReadonlyDictionary.cs new file mode 100644 index 000000000000..a013d3556df5 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/EmptyCollections/EmptyReadonlyDictionary.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; + +#pragma warning disable IDE0009 // use this directive +#pragma warning disable CA1716 + +// Original source from +// https://raw.githubusercontent.com/dotnet/extensions/main/src/Shared/EmptyCollections/EmptyReadOnlyList.cs + +[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +internal sealed class EmptyReadOnlyDictionary : IReadOnlyDictionary, IDictionary + where TKey : notnull +{ + public static readonly EmptyReadOnlyDictionary Instance = new(); + + public int Count => 0; + public TValue this[TKey key] => throw new KeyNotFoundException(); + public bool ContainsKey(TKey key) => false; + public IEnumerable Keys => EmptyReadOnlyList.Instance; + public IEnumerable Values => EmptyReadOnlyList.Instance; + + public IEnumerator> GetEnumerator() => EmptyReadOnlyList>.Instance.GetEnumerator(); + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + ICollection IDictionary.Keys => Array.Empty(); + ICollection IDictionary.Values => Array.Empty(); + bool ICollection>.IsReadOnly => true; + TValue IDictionary.this[TKey 
key] + { + get => throw new KeyNotFoundException(); + set => throw new NotSupportedException(); + } + + public bool TryGetValue(TKey key, out TValue value) + { +#pragma warning disable CS8601 // The recommended implementation: https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.dictionary-2.trygetvalue + value = default; +#pragma warning restore + + return false; + } + + void ICollection>.Clear() + { + // nop + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + // nop + } + + void IDictionary.Add(TKey key, TValue value) => throw new NotSupportedException(); + bool IDictionary.Remove(TKey key) => false; + void ICollection>.Add(KeyValuePair item) => throw new NotSupportedException(); + bool ICollection>.Contains(KeyValuePair item) => false; + bool ICollection>.Remove(KeyValuePair item) => false; +} diff --git a/dotnet/src/InternalUtilities/src/EmptyCollections/EmptyReadonlyList.cs b/dotnet/src/InternalUtilities/src/EmptyCollections/EmptyReadonlyList.cs new file mode 100644 index 000000000000..b2c730958691 --- /dev/null +++ b/dotnet/src/InternalUtilities/src/EmptyCollections/EmptyReadonlyList.cs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections; +using System.Collections.Generic; + +#pragma warning disable IDE0009 // use this directive +#pragma warning disable CA1716 + +// Original source from +// https://raw.githubusercontent.com/dotnet/extensions/main/src/Shared/EmptyCollections/EmptyReadOnlyList.cs + +[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] +[System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1001:Types that own disposable fields should be disposable", Justification = "Static field, lifetime matches the process")] +internal sealed class EmptyReadOnlyList : IReadOnlyList, ICollection +{ + public static readonly EmptyReadOnlyList Instance = new(); + private readonly Enumerator _enumerator = new(); + + public IEnumerator GetEnumerator() => _enumerator; + IEnumerator IEnumerable.GetEnumerator() => _enumerator; + public int Count => 0; + public T this[int index] => throw new ArgumentOutOfRangeException(nameof(index)); + + void ICollection.CopyTo(T[] array, int arrayIndex) + { + // nop + } + + bool ICollection.Contains(T item) => false; + bool ICollection.IsReadOnly => true; + void ICollection.Add(T item) => throw new NotSupportedException(); + bool ICollection.Remove(T item) => false; + + void ICollection.Clear() + { + // nop + } + + internal sealed class Enumerator : IEnumerator + { + public void Dispose() + { + // nop + } + + public void Reset() + { + // nop + } + + public bool MoveNext() => false; + public T Current => throw new InvalidOperationException(); + object IEnumerator.Current => throw new InvalidOperationException(); + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs new file mode 100644 index 000000000000..58ea317804f9 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.Services; + +namespace Microsoft.SemanticKernel.ChatCompletion; + +/// +/// Allow to be used as an in a +/// +internal sealed class ChatClientAIService : IAIService, IChatClient +{ + private readonly IChatClient _chatClient; + + /// + /// Storage for AI service attributes. + /// + internal Dictionary _internalAttributes { get; } = []; + + /// + /// Initializes a new instance of the class. + /// + /// Target . + internal ChatClientAIService(IChatClient chatClient) + { + Verify.NotNull(chatClient); + this._chatClient = chatClient; + + var metadata = this._chatClient.GetService(); + Verify.NotNull(metadata); + + this._internalAttributes[AIServiceExtensions.ModelIdKey] = metadata.DefaultModelId; + this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName; + this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri; + } + + /// + public IReadOnlyDictionary Attributes => this._internalAttributes; + + /// + public void Dispose() + { + } + + /// + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => this._chatClient.GetResponseAsync(messages, options, cancellationToken); + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + => this._chatClient.GetService(serviceType, serviceKey); + + /// + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) + => this._chatClient.GetStreamingResponseAsync(messages, options, cancellationToken); +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs new file mode 100644 index 000000000000..cac719649cab --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.SemanticKernel.ChatCompletion; + +/// Provides extension methods for . +public static class ChatClientExtensions +{ + /// + /// Get chat response which may contain multiple choices for the prompt and settings. + /// + /// Target chat client service. + /// The standardized prompt input. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . + /// Get chat response with choices generated by the remote model + internal static Task GetResponseAsync( + this IChatClient chatClient, + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? 
kernel = null, + CancellationToken cancellationToken = default) + { + var chatOptions = executionSettings.ToChatOptions(kernel); + + // Try to parse the text as a chat history + if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) + { + var messageList = chatHistoryFromPrompt.ToChatMessageList(); + return chatClient.GetResponseAsync(messageList, chatOptions, cancellationToken); + } + + return chatClient.GetResponseAsync(prompt, chatOptions, cancellationToken); + } + + /// Get ChatClient streaming response for the prompt, settings and kernel. + /// Target chat client service. + /// The standardized prompt input. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . + /// Streaming list of different completion streaming string updates generated by the remote model + internal static IAsyncEnumerable GetStreamingResponseAsync( + this IChatClient chatClient, + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var chatOptions = executionSettings.ToChatOptions(kernel); + + return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken); + } + + /// Creates an for the specified . + /// The chat client to be represented as a chat completion service. + /// An optional that can be used to resolve services to use in the instance. + /// + /// The . If is an , will + /// be returned. Otherwise, a new will be created that wraps . + /// + [Experimental("SKEXP0001")] + public static IChatCompletionService AsChatCompletionService(this IChatClient client, IServiceProvider? serviceProvider = null) + { + Verify.NotNull(client); + + return client is IChatCompletionService chatCompletionService ? 
+ chatCompletionService : + new ChatClientChatCompletionService(client, serviceProvider); + } + + /// + /// Get the model identifier for the specified . + /// + [Experimental("SKEXP0001")] + public static string? GetModelId(this IChatClient client) + { + Verify.NotNull(client); + + return client.GetService()?.DefaultModelId; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs new file mode 100644 index 000000000000..1501cb71d988 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.ChatCompletion; + +internal static class ChatMessageExtensions +{ + /// Converts a to a . + internal static ChatMessageContent ToChatMessageContent(this ChatMessage message, Microsoft.Extensions.AI.ChatResponse? response = null) + { + ChatMessageContent result = new() + { + ModelId = response?.ModelId, + AuthorName = message.AuthorName, + InnerContent = response?.RawRepresentation ?? message.RawRepresentation, + Metadata = message.AdditionalProperties, + Role = new AuthorRole(message.Role.Value), + }; + + foreach (AIContent content in message.Contents) + { + KernelContent? 
resultContent = content switch + { + Microsoft.Extensions.AI.TextContent tc => new Microsoft.SemanticKernel.TextContent(tc.Text), + Microsoft.Extensions.AI.DataContent dc when dc.HasTopLevelMediaType("image") => new Microsoft.SemanticKernel.ImageContent(dc.Uri), + Microsoft.Extensions.AI.UriContent uc when uc.HasTopLevelMediaType("image") => new Microsoft.SemanticKernel.ImageContent(uc.Uri), + Microsoft.Extensions.AI.DataContent dc when dc.HasTopLevelMediaType("audio") => new Microsoft.SemanticKernel.AudioContent(dc.Uri), + Microsoft.Extensions.AI.UriContent uc when uc.HasTopLevelMediaType("audio") => new Microsoft.SemanticKernel.AudioContent(uc.Uri), + Microsoft.Extensions.AI.DataContent dc => new Microsoft.SemanticKernel.BinaryContent(dc.Uri), + Microsoft.Extensions.AI.UriContent uc => new Microsoft.SemanticKernel.BinaryContent(uc.Uri), + Microsoft.Extensions.AI.FunctionCallContent fcc => new Microsoft.SemanticKernel.FunctionCallContent(fcc.Name, null, fcc.CallId, fcc.Arguments is not null ? new(fcc.Arguments) : null), + Microsoft.Extensions.AI.FunctionResultContent frc => new Microsoft.SemanticKernel.FunctionResultContent(callId: frc.CallId, result: frc.Result), + _ => null + }; + + if (resultContent is not null) + { + resultContent.Metadata = content.AdditionalProperties; + resultContent.InnerContent = content.RawRepresentation; + resultContent.ModelId = response?.ModelId; + result.Items.Add(resultContent); + } + } + + return result; + } + + /// Converts a list of to a . 
+ internal static ChatHistory ToChatHistory(this IEnumerable chatMessages) + { + ChatHistory chatHistory = []; + foreach (var message in chatMessages) + { + chatHistory.Add(message.ToChatMessageContent()); + } + return chatHistory; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs new file mode 100644 index 000000000000..d8fab37e57bd --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.ChatCompletion; + +/// +/// Extensions methods for . +/// +internal static class ChatOptionsExtensions +{ + /// Converts a to a . + internal static PromptExecutionSettings? ToPromptExecutionSettings(this ChatOptions? options) + { + if (options is null) + { + return null; + } + + PromptExecutionSettings settings = new() + { + ExtensionData = new Dictionary(StringComparer.OrdinalIgnoreCase), + ModelId = options.ModelId, + }; + + // Transfer over all strongly-typed members of ChatOptions. We do not know the exact name the derived PromptExecutionSettings + // will pick for these options, so we just use the most common choice for each. (We could make this more exact by having an + // IPromptExecutionSettingsFactory interface with a method like `PromptExecutionSettings Create(ChatOptions options)`; that + // interface could then optionally be implemented by an IChatCompletionService, and this implementation could just ask the + // chat completion service to produce the PromptExecutionSettings it wants. But, this is already a problem + // with PromptExecutionSettings, regardless of ChatOptions... 
someone creating a PES without knowing what backend is being + // used has to guess at the names to use.) + + if (options.Temperature is not null) + { + settings.ExtensionData["temperature"] = options.Temperature.Value; + } + + if (options.MaxOutputTokens is not null) + { + settings.ExtensionData["max_tokens"] = options.MaxOutputTokens.Value; + } + + if (options.FrequencyPenalty is not null) + { + settings.ExtensionData["frequency_penalty"] = options.FrequencyPenalty.Value; + } + + if (options.PresencePenalty is not null) + { + settings.ExtensionData["presence_penalty"] = options.PresencePenalty.Value; + } + + if (options.StopSequences is not null) + { + settings.ExtensionData["stop_sequences"] = options.StopSequences; + } + + if (options.TopP is not null) + { + settings.ExtensionData["top_p"] = options.TopP.Value; + } + + if (options.TopK is not null) + { + settings.ExtensionData["top_k"] = options.TopK.Value; + } + + if (options.Seed is not null) + { + settings.ExtensionData["seed"] = options.Seed.Value; + } + + if (options.ResponseFormat is not null) + { + if (options.ResponseFormat is ChatResponseFormatText) + { + settings.ExtensionData["response_format"] = "text"; + } + else if (options.ResponseFormat is ChatResponseFormatJson json) + { + settings.ExtensionData["response_format"] = json.Schema is JsonElement schema ? + JsonSerializer.Deserialize(schema, AbstractionsJsonContext.Default.JsonElement) : + "json_object"; + } + } + + // Transfer over loosely-typed members of ChatOptions. + + if (options.AdditionalProperties is not null) + { + foreach (var kvp in options.AdditionalProperties) + { + if (kvp.Value is not null) + { + settings.ExtensionData[kvp.Key] = kvp.Value; + } + } + } + + // Transfer over tools. For IChatClient, we do not want automatic invocation, as that's a concern left up to + // components like FunctionInvocationChatClient. 
As such, based on the tool mode, we map to the appropriate + // FunctionChoiceBehavior, but always with autoInvoke: false. + + if (options.Tools is { Count: > 0 }) + { + var functions = options.Tools.OfType().Select(f => new AIFunctionKernelFunction(f)); + settings.FunctionChoiceBehavior = + options.ToolMode is null or AutoChatToolMode ? FunctionChoiceBehavior.Auto(functions, autoInvoke: false) : + options.ToolMode is RequiredChatToolMode { RequiredFunctionName: null } ? FunctionChoiceBehavior.Required(functions, autoInvoke: false) : + options.ToolMode is RequiredChatToolMode { RequiredFunctionName: string functionName } ? FunctionChoiceBehavior.Required(functions.Where(f => f.Name == functionName), autoInvoke: false) : + null; + } + + return settings; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatResponseUpdateExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatResponseUpdateExtensions.cs new file mode 100644 index 000000000000..8ec9698484b0 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatResponseUpdateExtensions.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for . +internal static class ChatResponseUpdateExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. + internal static StreamingChatMessageContent ToStreamingChatMessageContent(this ChatResponseUpdate update) + { + StreamingChatMessageContent content = new( + update.Role is not null ? new AuthorRole(update.Role.Value.Value) : null, + null) + { + InnerContent = update.RawRepresentation, + Metadata = update.AdditionalProperties, + ModelId = update.ModelId + }; + + foreach (AIContent item in update.Contents) + { + StreamingKernelContent? 
resultContent = + item is Microsoft.Extensions.AI.TextContent tc ? new Microsoft.SemanticKernel.StreamingTextContent(tc.Text) : + item is Microsoft.Extensions.AI.FunctionCallContent fcc ? + new Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent(fcc.CallId, fcc.Name, fcc.Arguments is not null ? + JsonSerializer.Serialize(fcc.Arguments!, AbstractionsJsonContext.Default.IDictionaryStringObject!) : + null) : + null; + + if (resultContent is not null) + { + resultContent.ModelId = update.ModelId; + content.Items.Add(resultContent); + } + } + + return content; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelChatOptions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelChatOptions.cs new file mode 100644 index 000000000000..70656742be12 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelChatOptions.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel; + +namespace Microsoft.Extensions.AI; + +/// +/// This class allows a , and a to be used +/// for internal creation. This avoids any leaking information in the lower-level ChatOptions +/// during serialization of in calls to AI Models. +/// +internal class KernelChatOptions : ChatOptions +{ + /// + /// Initializes a new instance of the class. + /// + /// Target kernel. + /// Original chat options. + /// Prompt execution settings. + internal KernelChatOptions(Kernel kernel, ChatOptions? options = null, PromptExecutionSettings? 
settings = null) + { + Verify.NotNull(kernel); + + if (options is not null) + { + this.AdditionalProperties = options.AdditionalProperties; + this.AllowMultipleToolCalls = options.AllowMultipleToolCalls; + this.Tools = options.Tools; + this.Temperature = options.Temperature; + this.TopP = options.TopP; + this.TopK = options.TopK; + this.Seed = options.Seed; + this.ResponseFormat = options.ResponseFormat; + this.MaxOutputTokens = options.MaxOutputTokens; + this.FrequencyPenalty = options.FrequencyPenalty; + this.PresencePenalty = options.PresencePenalty; + this.StopSequences = options.StopSequences; + this.RawRepresentationFactory = options.RawRepresentationFactory; + this.ConversationId = options.ConversationId; + this.Seed = options.Seed; + this.ToolMode = options.ToolMode; + this.ModelId = options.ModelId; + } + + this.ExecutionSettings = settings; + this.Kernel = kernel; + } + + [JsonIgnore] + public ChatMessageContent? ChatMessageContent { get; internal set; } + + [JsonIgnore] + public Kernel Kernel { get; } + + [JsonIgnore] + public PromptExecutionSettings? ExecutionSettings { get; internal set; } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs new file mode 100644 index 000000000000..968a51b411b9 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.Extensions.AI; + +/// +/// Specialization of that uses and supports . 
+/// +internal sealed class KernelFunctionInvokingChatClient : FunctionInvokingChatClient +{ + /// + public KernelFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) + : base(innerClient, loggerFactory, functionInvocationServices) + { + this.MaximumIterationsPerRequest = 128; + } + + /// + /// Invokes the auto function invocation filters. + /// + /// The auto function invocation context. + /// The function to call after the filters. + /// The auto function invocation context. + private async Task OnAutoFunctionInvocationAsync( + AutoFunctionInvocationContext context, + Func functionCallCallback) + { + await this.InvokeFilterOrFunctionAsync(functionCallCallback, context).ConfigureAwait(false); + + return context; + } + + /// + /// This method will execute auto function invocation filters and function recursively. + /// If there are no registered filters, just function will be executed. + /// If there are registered filters, filter on position will be executed. + /// Second parameter of filter is callback. It can be either filter on + 1 position or function if there are no remaining filters to execute. + /// Function will always be executed as last step after all filters. 
+ /// + private async Task InvokeFilterOrFunctionAsync( + Func functionCallCallback, + AutoFunctionInvocationContext context, + int index = 0) + { + IList autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters; + + if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count) + { + await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( + context, + (ctx) => this.InvokeFilterOrFunctionAsync(functionCallCallback, ctx, index + 1) + ).ConfigureAwait(false); + } + else + { + await functionCallCallback(context).ConfigureAwait(false); + } + } + + /// + protected override async ValueTask InvokeFunctionAsync(Microsoft.Extensions.AI.FunctionInvocationContext context, CancellationToken cancellationToken) + { + if (context.Options is null || context.Options is not KernelChatOptions kernelChatOptions) + { + return await context.Function.InvokeAsync(context.Arguments, cancellationToken).ConfigureAwait(false); + } + + object? result = null; + + kernelChatOptions.ChatMessageContent = context.Messages.Last().ToChatMessageContent(); + + var autoContext = new AutoFunctionInvocationContext(kernelChatOptions, context.Function) + { + AIFunction = context.Function, + Arguments = new KernelArguments(context.Arguments) { Services = this.FunctionInvocationServices }, + Messages = context.Messages, + CallContent = context.CallContent, + Iteration = context.Iteration, + FunctionCallIndex = context.FunctionCallIndex, + FunctionCount = context.FunctionCount, + IsStreaming = context.IsStreaming + }; + + autoContext = await this.OnAutoFunctionInvocationAsync( + autoContext, + async (ctx) => + { + // Check if filter requested termination + if (ctx.Terminate) + { + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. 
In particular, we must not use function calling settings naively here, + as the called function could in turn tell the model about itself as a possible candidate for invocation. + result = await autoContext.AIFunction.InvokeAsync(autoContext.Arguments, cancellationToken).ConfigureAwait(false); + ctx.Result = new FunctionResult(ctx.Function, result); + }).ConfigureAwait(false); + result = autoContext.Result.GetValue(); + + context.Terminate = autoContext.Terminate; + + return result; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientBuilderExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientBuilderExtensions.cs new file mode 100644 index 000000000000..7d30c16ab231 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientBuilderExtensions.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Microsoft.SemanticKernel; + +/// Provides extensions for configuring instances. +[Experimental("SKEXP0001")] +public static class KernelFunctionInvokingChatClientBuilderExtensions +{ + /// + /// Enables automatic function call invocation on the chat pipeline. + /// + /// This works by adding an instance of with default options. + /// The being used to build the chat pipeline. + /// An optional to use to create a logger for logging function invocations. + /// The supplied . + /// is . + public static ChatClientBuilder UseKernelFunctionInvocation( + this ChatClientBuilder builder, + ILoggerFactory?
loggerFactory = null) + { + _ = Throw.IfNull(builder); + + return builder.Use((innerClient, services) => + { + loggerFactory ??= services.GetService(); + + return new KernelFunctionInvokingChatClient(innerClient, loggerFactory, services); + }); + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionArgumentsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionArgumentsExtensions.cs new file mode 100644 index 000000000000..c3af5fe46684 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionArgumentsExtensions.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +internal static class AIFunctionArgumentsExtensions +{ + public const string KernelAIFunctionArgumentKey = $"{nameof(AIFunctionArguments)}_{nameof(Kernel)}"; + + internal static AIFunctionArguments AddKernel(this AIFunctionArguments arguments, Kernel kernel) + { + Verify.NotNull(arguments); + arguments[KernelAIFunctionArgumentKey] = kernel; + + return arguments; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionExtensions.cs index abf08ba2ca29..df9f24e7d9aa 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionExtensions.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.ChatCompletion; @@ -12,7 +14,7 @@ namespace Microsoft.SemanticKernel; public static class AIFunctionExtensions { /// - /// Converts an to a . + /// Converts an to a if it is not already one. /// /// The AI function to convert. /// The converted . 
@@ -20,6 +22,32 @@ public static class AIFunctionExtensions public static KernelFunction AsKernelFunction(this AIFunction aiFunction) { Verify.NotNull(aiFunction); - return new AIFunctionKernelFunction(aiFunction); + return aiFunction is KernelFunction kf + ? kf + : new AIFunctionKernelFunction(aiFunction); + } + + /// + /// Invokes the providing a and returns its result. + /// + /// Represents the AI function to be executed. + /// instance to be used when the is a . + /// Contains the arguments required for the AI function execution. + /// Allows for the operation to be canceled if needed. + /// The result of the function execution. + [Experimental("SKEXP0001")] + public static ValueTask InvokeAsync(this AIFunction aiFunction, Kernel kernel, AIFunctionArguments? functionArguments = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(aiFunction); + + AIFunctionArguments? functionArgumentsClone = null; + if (aiFunction is KernelFunction) + { + // If the AIFunction is a KernelFunction inject the provided kernel in the cloned arguments + functionArgumentsClone = new AIFunctionArguments(functionArguments) + .AddKernel(kernel); + } + + return aiFunction.InvokeAsync(functionArgumentsClone ?? functionArguments, cancellationToken); } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionKernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionKernelFunction.cs index 693fedbad951..ee697b48bd1d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionKernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/AIFunctionKernelFunction.cs @@ -16,13 +16,19 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// Provides a that wraps an . internal sealed class AIFunctionKernelFunction : KernelFunction { + private readonly string? 
_pluginName; private readonly AIFunction _aiFunction; + public override string Name => ((AIFunction)this).Name; + + public override string? PluginName => (this._pluginName is null && this._aiFunction is KernelFunction kf) ? kf.PluginName : this._pluginName; + public AIFunctionKernelFunction(AIFunction aiFunction) : - base(aiFunction.Name, - aiFunction.Description, - MapParameterMetadata(aiFunction), - aiFunction.JsonSerializerOptions, + base( + name: aiFunction.Name, + description: aiFunction.Description, + parameters: MapParameterMetadata(aiFunction), + jsonSerializerOptions: aiFunction.JsonSerializerOptions, new KernelReturnParameterMetadata(AbstractionsJsonContext.Default.Options) { Description = aiFunction.UnderlyingMethod?.ReturnParameter.GetCustomAttribute()?.Description, @@ -30,24 +36,34 @@ public AIFunctionKernelFunction(AIFunction aiFunction) : Schema = new KernelJsonSchema(AIJsonUtilities.CreateJsonSchema(aiFunction.UnderlyingMethod?.ReturnParameter.ParameterType)), }) { + // Kernel functions created from AI functions are always fully qualified this._aiFunction = aiFunction; } private AIFunctionKernelFunction(AIFunctionKernelFunction other, string pluginName) : base(other.Name, pluginName, other.Description, other.Metadata.Parameters, AbstractionsJsonContext.Default.Options, other.Metadata.ReturnParameter) { + this._pluginName = pluginName; this._aiFunction = other._aiFunction; } - public override KernelFunction Clone(string pluginName) + public override KernelFunction Clone(string? 
pluginName = null) { Verify.NotNullOrWhiteSpace(pluginName); + return new AIFunctionKernelFunction(this, pluginName); } + public override string ToString() => this.Name; + protected override async ValueTask InvokeCoreAsync( Kernel kernel, KernelArguments arguments, CancellationToken cancellationToken) { + if (this._aiFunction is KernelFunction kernelFunction) + { + return await kernelFunction.InvokeAsync(kernel, arguments, cancellationToken).ConfigureAwait(false); + } + object? result = await this._aiFunction.InvokeAsync(new(arguments), cancellationToken).ConfigureAwait(false); return new FunctionResult(this, result); } @@ -61,6 +77,11 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync MapParameterMetadata(AIFunction aiFunction) { + if (aiFunction is KernelFunction kernelFunction) + { + return kernelFunction.Metadata.Parameters; + } + if (!aiFunction.JsonSchema.TryGetProperty("properties", out JsonElement properties)) { return Array.Empty(); diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs index bff679677703..e8251d450a77 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatClientChatCompletionService.cs @@ -3,12 +3,8 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; -using System.Diagnostics.CodeAnalysis; -using System.Globalization; using System.Linq; using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -23,7 +19,7 @@ internal sealed class ChatClientChatCompletionService : IChatCompletionService private readonly IChatClient _chatClient; /// Initializes the for . 
- public ChatClientChatCompletionService(IChatClient chatClient, IServiceProvider? serviceProvider) + internal ChatClientChatCompletionService(IChatClient chatClient, IServiceProvider? serviceProvider) { Verify.NotNull(chatClient); @@ -54,9 +50,12 @@ public async Task> GetChatMessageContentsAsync { Verify.NotNull(chatHistory); + var messageList = chatHistory.ToChatMessageList(); + var currentSize = messageList.Count; + var completion = await this._chatClient.GetResponseAsync( - ChatCompletionServiceExtensions.ToChatMessageList(chatHistory), - ToChatOptions(executionSettings, kernel), + messageList, + executionSettings.ToChatOptions(kernel), cancellationToken).ConfigureAwait(false); if (completion.Messages.Count > 0) @@ -64,11 +63,11 @@ public async Task> GetChatMessageContentsAsync // Add all but the last message into the chat history. for (int i = 0; i < completion.Messages.Count - 1; i++) { - chatHistory.Add(ChatCompletionServiceExtensions.ToChatMessageContent(completion.Messages[i], completion)); + chatHistory.Add(completion.Messages[i].ToChatMessageContent(completion)); } // Return the last message as the result. - return [ChatCompletionServiceExtensions.ToChatMessageContent(completion.Messages[completion.Messages.Count - 1], completion)]; + return [completion.Messages[completion.Messages.Count - 1].ToChatMessageContent(completion)]; } return []; @@ -84,240 +83,30 @@ public async IAsyncEnumerable GetStreamingChatMessa ChatRole? 
role = null; await foreach (var update in this._chatClient.GetStreamingResponseAsync( - ChatCompletionServiceExtensions.ToChatMessageList(chatHistory), - ToChatOptions(executionSettings, kernel), + chatHistory.ToChatMessageList(), + executionSettings.ToChatOptions(kernel), cancellationToken).ConfigureAwait(false)) { role ??= update.Role; fcContents.AddRange(update.Contents.Where(c => c is Microsoft.Extensions.AI.FunctionCallContent or Microsoft.Extensions.AI.FunctionResultContent)); - yield return ToStreamingChatMessageContent(update); - } - - // Add function call content/results to chat history, as other IChatCompletionService streaming implementations do. - chatHistory.Add(ChatCompletionServiceExtensions.ToChatMessageContent(new ChatMessage(role ?? ChatRole.Assistant, fcContents))); - } - - /// Converts a pair of and to a . - private static ChatOptions? ToChatOptions(PromptExecutionSettings? settings, Kernel? kernel) - { - if (settings is null) - { - return null; - } - - if (settings.GetType() != typeof(PromptExecutionSettings)) - { - // If the settings are of a derived type, roundtrip through JSON to the base type in order to try - // to get the derived strongly-typed properties to show up in the loosely-typed ExtensionData dictionary. - // This has the unfortunate effect of making all the ExtensionData values into JsonElements, so we lose - // some type fidelity. (As an alternative, we could introduce new interfaces that could be queried for - // in this method and implemented by the derived settings types to control how they're converted to - // ChatOptions.) 
- settings = JsonSerializer.Deserialize( - JsonSerializer.Serialize(settings, AbstractionsJsonContext.GetTypeInfo(settings.GetType(), null)), - AbstractionsJsonContext.Default.PromptExecutionSettings); + yield return update.ToStreamingChatMessageContent(); } - ChatOptions options = new() - { - ModelId = settings!.ModelId - }; - - if (settings!.ExtensionData is IDictionary extensionData) + // Message tools and function calls should be individual messages in the history. + foreach (var fcc in fcContents) { - foreach (var entry in extensionData) + if (fcc is Microsoft.Extensions.AI.FunctionCallContent functionCallContent) { - if (entry.Key.Equals("temperature", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float temperature)) - { - options.Temperature = temperature; - } - else if (entry.Key.Equals("top_p", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float topP)) - { - options.TopP = topP; - } - else if (entry.Key.Equals("top_k", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out int topK)) - { - options.TopK = topK; - } - else if (entry.Key.Equals("seed", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out long seed)) - { - options.Seed = seed; - } - else if (entry.Key.Equals("max_tokens", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out int maxTokens)) - { - options.MaxOutputTokens = maxTokens; - } - else if (entry.Key.Equals("frequency_penalty", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float frequencyPenalty)) - { - options.FrequencyPenalty = frequencyPenalty; - } - else if (entry.Key.Equals("presence_penalty", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out float presencePenalty)) - { - options.PresencePenalty = presencePenalty; - } - else if (entry.Key.Equals("stop_sequences", StringComparison.OrdinalIgnoreCase) && - TryConvert(entry.Value, out IList? 
stopSequences)) - { - options.StopSequences = stopSequences; - } - else if (entry.Key.Equals("response_format", StringComparison.OrdinalIgnoreCase) && - entry.Value is { } responseFormat) - { - if (TryConvert(responseFormat, out string? responseFormatString)) - { - options.ResponseFormat = responseFormatString switch - { - "text" => ChatResponseFormat.Text, - "json_object" => ChatResponseFormat.Json, - _ => null, - }; - } - else - { - options.ResponseFormat = responseFormat is JsonElement e ? ChatResponseFormat.ForJsonSchema(e) : null; - } - } - else - { - // Roundtripping a derived PromptExecutionSettings through the base type will have put all the - // object values in AdditionalProperties into JsonElements. Convert them back where possible. - object? value = entry.Value; - if (value is JsonElement jsonElement) - { - value = jsonElement.ValueKind switch - { - JsonValueKind.String => jsonElement.GetString(), - JsonValueKind.Number => jsonElement.GetDouble(), // not perfect, but a reasonable heuristic - JsonValueKind.True => true, - JsonValueKind.False => false, - JsonValueKind.Null => null, - _ => value, - }; - - if (jsonElement.ValueKind == JsonValueKind.Array) - { - var enumerator = jsonElement.EnumerateArray(); - - var enumeratorType = enumerator.MoveNext() ? 
enumerator.Current.ValueKind : JsonValueKind.Null; - - switch (enumeratorType) - { - case JsonValueKind.String: - value = enumerator.Select(e => e.GetString()); - break; - case JsonValueKind.Number: - value = enumerator.Select(e => e.GetDouble()); - break; - case JsonValueKind.True or JsonValueKind.False: - value = enumerator.Select(e => e.ValueKind == JsonValueKind.True); - break; - } - } - } - - (options.AdditionalProperties ??= [])[entry.Key] = value; - } + chatHistory.Add(new ChatMessage(ChatRole.Assistant, [functionCallContent]).ToChatMessageContent()); + continue; } - } - if (settings.FunctionChoiceBehavior?.GetConfiguration(new([]) { Kernel = kernel }).Functions is { Count: > 0 } functions) - { - options.ToolMode = settings.FunctionChoiceBehavior is RequiredFunctionChoiceBehavior ? ChatToolMode.RequireAny : ChatToolMode.Auto; - options.Tools = functions.Select(f => f.AsAIFunction(kernel)).Cast().ToList(); - } - - return options; - - // Be a little lenient on the types of the values used in the extension data, - // e.g. allow doubles even when requesting floats. - static bool TryConvert(object? value, [NotNullWhen(true)] out T? result) - { - if (value is not null) + if (fcc is Microsoft.Extensions.AI.FunctionResultContent functionResultContent) { - // If the value is a T, use it. - if (value is T typedValue) - { - result = typedValue; - return true; - } - - if (value is JsonElement json) - { - // If the value is JsonElement, it likely resulted from JSON serializing as object. - // Try to deserialize it as a T. This currently will only be successful either when - // reflection-based serialization is enabled or T is one of the types special-cased - // in the AbstractionsJsonContext. For other cases with NativeAOT, we would need to - // have a JsonSerializationOptions with the relevant type information. - if (AbstractionsJsonContext.TryGetTypeInfo(typeof(T), firstOptions: null, out JsonTypeInfo? 
jti)) - { - try - { - result = (T)json.Deserialize(jti)!; - return true; - } - catch (Exception e) when (e is ArgumentException or JsonException or NotSupportedException or InvalidOperationException) - { - } - } - } - else - { - // Otherwise, try to convert it to a T using Convert, in particular to handle conversions between numeric primitive types. - try - { - result = (T)Convert.ChangeType(value, typeof(T), CultureInfo.InvariantCulture); - return true; - } - catch (Exception e) when (e is ArgumentException or FormatException or InvalidCastException or OverflowException) - { - } - } + chatHistory.Add(new ChatMessage(ChatRole.Tool, [functionResultContent]).ToChatMessageContent()); } - - result = default; - return false; } } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - private static StreamingChatMessageContent ToStreamingChatMessageContent(ChatResponseUpdate update) - { - StreamingChatMessageContent content = new( - update.Role is not null ? new AuthorRole(update.Role.Value.Value) : null, - null) - { - InnerContent = update.RawRepresentation, - Metadata = update.AdditionalProperties, - ModelId = update.ModelId - }; - - foreach (AIContent item in update.Contents) - { - StreamingKernelContent? resultContent = - item is Microsoft.Extensions.AI.TextContent tc ? new Microsoft.SemanticKernel.StreamingTextContent(tc.Text) : - item is Microsoft.Extensions.AI.FunctionCallContent fcc ? - new Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent(fcc.CallId, fcc.Name, fcc.Arguments is not null ? - JsonSerializer.Serialize(fcc.Arguments!, AbstractionsJsonContext.Default.IDictionaryStringObject!) 
: - null) : - null; - - if (resultContent is not null) - { - resultContent.ModelId = update.ModelId; - content.Items.Add(resultContent); - } - } - - return content; - } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs index da55fa60cee7..45b912a7d884 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceChatClient.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,7 +18,7 @@ internal sealed class ChatCompletionServiceChatClient : IChatClient private readonly IChatCompletionService _chatCompletionService; /// Initializes the for . - public ChatCompletionServiceChatClient(IChatCompletionService chatCompletionService) + internal ChatCompletionServiceChatClient(IChatCompletionService chatCompletionService) { Verify.NotNull(chatCompletionService); @@ -39,12 +38,12 @@ public ChatCompletionServiceChatClient(IChatCompletionService chatCompletionServ { Verify.NotNull(messages); - ChatHistory chatHistory = new(messages.Select(m => ChatCompletionServiceExtensions.ToChatMessageContent(m))); + ChatHistory chatHistory = new(messages.Select(m => m.ToChatMessageContent())); int preCount = chatHistory.Count; var response = await this._chatCompletionService.GetChatMessageContentAsync( chatHistory, - ToPromptExecutionSettings(options), + options.ToPromptExecutionSettings(), kernel: null, cancellationToken).ConfigureAwait(false); @@ -58,10 +57,10 @@ public ChatCompletionServiceChatClient(IChatCompletionService chatCompletionServ // Then add the result message. 
for (int i = preCount; i < chatHistory.Count; i++) { - chatResponse.Messages.Add(ChatCompletionServiceExtensions.ToChatMessage(chatHistory[i])); + chatResponse.Messages.Add(chatHistory[i].ToChatMessage()); } - chatResponse.Messages.Add(ChatCompletionServiceExtensions.ToChatMessage(response)); + chatResponse.Messages.Add(response.ToChatMessage()); return chatResponse; } @@ -72,12 +71,12 @@ public async IAsyncEnumerable GetStreamingResponseAsync(IEnu Verify.NotNull(messages); await foreach (var update in this._chatCompletionService.GetStreamingChatMessageContentsAsync( - new ChatHistory(messages.Select(m => ChatCompletionServiceExtensions.ToChatMessageContent(m))), - ToPromptExecutionSettings(options), + new ChatHistory(messages.Select(m => m.ToChatMessageContent())), + options.ToPromptExecutionSettings(), kernel: null, cancellationToken).ConfigureAwait(false)) { - yield return ToStreamingChatCompletionUpdate(update); + yield return update.ToChatResponseUpdate(); } } @@ -99,151 +98,4 @@ public void Dispose() serviceType.IsInstanceOfType(this.Metadata) ? this.Metadata : null; } - - /// Converts a to a . - private static PromptExecutionSettings? ToPromptExecutionSettings(ChatOptions? options) - { - if (options is null) - { - return null; - } - - PromptExecutionSettings settings = new() - { - ExtensionData = new Dictionary(StringComparer.OrdinalIgnoreCase), - ModelId = options.ModelId, - }; - - // Transfer over all strongly-typed members of ChatOptions. We do not know the exact name the derived PromptExecutionSettings - // will pick for these options, so we just use the most common choice for each. (We could make this more exact by having an - // IPromptExecutionSettingsFactory interface with a method like `PromptExecutionSettings Create(ChatOptions options)`; that - // interface could then optionally be implemented by an IChatCompletionService, and this implementation could just ask the - // chat completion service to produce the PromptExecutionSettings it wants. 
But, this is already a problem - // with PromptExecutionSettings, regardless of ChatOptions... someone creating a PES without knowing what backend is being - // used has to guess at the names to use.) - - if (options.Temperature is not null) - { - settings.ExtensionData["temperature"] = options.Temperature.Value; - } - - if (options.MaxOutputTokens is not null) - { - settings.ExtensionData["max_tokens"] = options.MaxOutputTokens.Value; - } - - if (options.FrequencyPenalty is not null) - { - settings.ExtensionData["frequency_penalty"] = options.FrequencyPenalty.Value; - } - - if (options.PresencePenalty is not null) - { - settings.ExtensionData["presence_penalty"] = options.PresencePenalty.Value; - } - - if (options.StopSequences is not null) - { - settings.ExtensionData["stop_sequences"] = options.StopSequences; - } - - if (options.TopP is not null) - { - settings.ExtensionData["top_p"] = options.TopP.Value; - } - - if (options.TopK is not null) - { - settings.ExtensionData["top_k"] = options.TopK.Value; - } - - if (options.Seed is not null) - { - settings.ExtensionData["seed"] = options.Seed.Value; - } - - if (options.ResponseFormat is not null) - { - if (options.ResponseFormat is ChatResponseFormatText) - { - settings.ExtensionData["response_format"] = "text"; - } - else if (options.ResponseFormat is ChatResponseFormatJson json) - { - settings.ExtensionData["response_format"] = json.Schema is JsonElement schema ? - JsonSerializer.Deserialize(schema, AbstractionsJsonContext.Default.JsonElement) : - "json_object"; - } - } - - // Transfer over loosely-typed members of ChatOptions. - - if (options.AdditionalProperties is not null) - { - foreach (var kvp in options.AdditionalProperties) - { - if (kvp.Value is not null) - { - settings.ExtensionData[kvp.Key] = kvp.Value; - } - } - } - - // Transfer over tools. For IChatClient, we do not want automatic invocation, as that's a concern left up to - // components like FunctionInvocationChatClient. 
As such, based on the tool mode, we map to the appropriate - // FunctionChoiceBehavior, but always with autoInvoke: false. - - if (options.Tools is { Count: > 0 }) - { - var functions = options.Tools.OfType().Select(aiFunction => aiFunction.AsKernelFunction()); - settings.FunctionChoiceBehavior = - options.ToolMode is null or AutoChatToolMode ? FunctionChoiceBehavior.Auto(functions, autoInvoke: false) : - options.ToolMode is RequiredChatToolMode { RequiredFunctionName: null } ? FunctionChoiceBehavior.Required(functions, autoInvoke: false) : - options.ToolMode is RequiredChatToolMode { RequiredFunctionName: string functionName } ? FunctionChoiceBehavior.Required(functions.Where(f => f.Name == functionName), autoInvoke: false) : - null; - } - - return settings; - } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - private static ChatResponseUpdate ToStreamingChatCompletionUpdate(StreamingChatMessageContent content) - { - ChatResponseUpdate update = new() - { - AdditionalProperties = content.Metadata is not null ? new AdditionalPropertiesDictionary(content.Metadata) : null, - AuthorName = content.AuthorName, - ModelId = content.ModelId, - RawRepresentation = content, - Role = content.Role is not null ? new ChatRole(content.Role.Value.Label) : null, - }; - - foreach (var item in content.Items) - { - AIContent? aiContent = null; - switch (item) - { - case Microsoft.SemanticKernel.StreamingTextContent tc: - aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); - break; - - case Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent fcc: - aiContent = new Microsoft.Extensions.AI.FunctionCallContent( - fcc.CallId ?? string.Empty, - fcc.Name ?? string.Empty, - fcc.Arguments is not null ? JsonSerializer.Deserialize>(fcc.Arguments, AbstractionsJsonContext.Default.IDictionaryStringObject!) 
: null); - break; - } - - if (aiContent is not null) - { - aiContent.RawRepresentation = content; - - update.Contents.Add(aiContent); - } - } - - return update; - } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs index f61a61a6687e..844d940e5e54 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatCompletionServiceExtensions.cs @@ -128,128 +128,4 @@ public static IChatClient AsChatClient(this IChatCompletionService service) chatClient : new ChatCompletionServiceChatClient(service); } - - /// Creates an for the specified . - /// The chat client to be represented as a chat completion service. - /// An optional that can be used to resolve services to use in the instance. - /// - /// The . If is an , will - /// be returned. Otherwise, a new will be created that wraps . - /// - [Experimental("SKEXP0001")] - public static IChatCompletionService AsChatCompletionService(this IChatClient client, IServiceProvider? serviceProvider = null) - { - Verify.NotNull(client); - - return client is IChatCompletionService chatCompletionService ? - chatCompletionService : - new ChatClientChatCompletionService(client, serviceProvider); - } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - internal static ChatMessage ToChatMessage(ChatMessageContent content) - { - ChatMessage message = new() - { - AdditionalProperties = content.Metadata is not null ? new(content.Metadata) : null, - AuthorName = content.AuthorName, - RawRepresentation = content.InnerContent, - Role = content.Role.Label is string label ? new ChatRole(label) : ChatRole.User, - }; - - foreach (var item in content.Items) - { - AIContent? 
aiContent = null; - switch (item) - { - case Microsoft.SemanticKernel.TextContent tc: - aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); - break; - - case Microsoft.SemanticKernel.ImageContent ic: - aiContent = - ic.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ic.DataUri, ic.MimeType) : - ic.Uri is not null ? new Microsoft.Extensions.AI.UriContent(ic.Uri, ic.MimeType ?? "image/*") : - null; - break; - - case Microsoft.SemanticKernel.AudioContent ac: - aiContent = - ac.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ac.DataUri, ac.MimeType) : - ac.Uri is not null ? new Microsoft.Extensions.AI.UriContent(ac.Uri, ac.MimeType ?? "audio/*") : - null; - break; - - case Microsoft.SemanticKernel.BinaryContent bc: - aiContent = - bc.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(bc.DataUri, bc.MimeType) : - bc.Uri is not null ? new Microsoft.Extensions.AI.UriContent(bc.Uri, bc.MimeType ?? "application/octet-stream") : - null; - break; - - case Microsoft.SemanticKernel.FunctionCallContent fcc: - aiContent = new Microsoft.Extensions.AI.FunctionCallContent(fcc.Id ?? string.Empty, fcc.FunctionName, fcc.Arguments); - break; - - case Microsoft.SemanticKernel.FunctionResultContent frc: - aiContent = new Microsoft.Extensions.AI.FunctionResultContent(frc.CallId ?? string.Empty, frc.Result); - break; - } - - if (aiContent is not null) - { - aiContent.RawRepresentation = item.InnerContent; - aiContent.AdditionalProperties = item.Metadata is not null ? new(item.Metadata) : null; - - message.Contents.Add(aiContent); - } - } - - return message; - } - - /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. - internal static ChatMessageContent ToChatMessageContent(ChatMessage message, Microsoft.Extensions.AI.ChatResponse? 
response = null) - { - ChatMessageContent result = new() - { - ModelId = response?.ModelId, - AuthorName = message.AuthorName, - InnerContent = response?.RawRepresentation ?? message.RawRepresentation, - Metadata = message.AdditionalProperties, - Role = new AuthorRole(message.Role.Value), - }; - - foreach (AIContent content in message.Contents) - { - KernelContent? resultContent = content switch - { - Microsoft.Extensions.AI.TextContent tc => new Microsoft.SemanticKernel.TextContent(tc.Text), - Microsoft.Extensions.AI.DataContent dc when dc.HasTopLevelMediaType("image") => new Microsoft.SemanticKernel.ImageContent(dc.Uri), - Microsoft.Extensions.AI.UriContent uc when uc.HasTopLevelMediaType("image") => new Microsoft.SemanticKernel.ImageContent(uc.Uri), - Microsoft.Extensions.AI.DataContent dc when dc.HasTopLevelMediaType("audio") => new Microsoft.SemanticKernel.AudioContent(dc.Uri), - Microsoft.Extensions.AI.UriContent uc when uc.HasTopLevelMediaType("audio") => new Microsoft.SemanticKernel.AudioContent(uc.Uri), - Microsoft.Extensions.AI.DataContent dc => new Microsoft.SemanticKernel.BinaryContent(dc.Uri), - Microsoft.Extensions.AI.UriContent uc => new Microsoft.SemanticKernel.BinaryContent(uc.Uri), - Microsoft.Extensions.AI.FunctionCallContent fcc => new Microsoft.SemanticKernel.FunctionCallContent(fcc.Name, null, fcc.CallId, fcc.Arguments is not null ? 
new(fcc.Arguments) : null), - Microsoft.Extensions.AI.FunctionResultContent frc => new Microsoft.SemanticKernel.FunctionResultContent(callId: frc.CallId, result: frc.Result), - _ => null - }; - - if (resultContent is not null) - { - resultContent.Metadata = content.AdditionalProperties; - resultContent.InnerContent = content.RawRepresentation; - resultContent.ModelId = response?.ModelId; - result.Items.Add(resultContent); - } - } - - return result; - } - - internal static List ToChatMessageList(ChatHistory chatHistory) - => chatHistory.Select(ToChatMessage).ToList(); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs index 22968c47ea38..147cdd5ba332 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs @@ -18,6 +18,14 @@ public class ChatHistory : IList, IReadOnlyListThe messages. private readonly List _messages; + private Action? _overrideAdd; + private Func? _overrideRemove; + private Action? _overrideClear; + private Action? _overrideInsert; + private Action? _overrideRemoveAt; + private Action? _overrideRemoveRange; + private Action>? _overrideAddRange; + /// Initializes an empty history. 
/// /// Creates a new instance of the class @@ -27,6 +35,38 @@ public ChatHistory() this._messages = []; } + // This allows observation of the chat history changes by-reference reflecting in an + // internal IEnumerable when used from IChatClients + // with AutoFunctionInvocationFilters + internal void SetOverrides( + Action overrideAdd, + Func overrideRemove, + Action onClear, + Action overrideInsert, + Action overrideRemoveAt, + Action overrideRemoveRange, + Action> overrideAddRange) + { + this._overrideAdd = overrideAdd; + this._overrideRemove = overrideRemove; + this._overrideClear = onClear; + this._overrideInsert = overrideInsert; + this._overrideRemoveAt = overrideRemoveAt; + this._overrideRemoveRange = overrideRemoveRange; + this._overrideAddRange = overrideAddRange; + } + + internal void ClearOverrides() + { + this._overrideAdd = null; + this._overrideRemove = null; + this._overrideClear = null; + this._overrideInsert = null; + this._overrideRemoveAt = null; + this._overrideRemoveRange = null; + this._overrideAddRange = null; + } + /// /// Creates a new instance of the with a first message in the provided . /// If not role is provided then the first message will default to role. @@ -37,8 +77,7 @@ public ChatHistory(string message, AuthorRole role) { Verify.NotNullOrWhiteSpace(message); - this._messages = []; - this.Add(new ChatMessageContent(role, message)); + this._messages = [new ChatMessageContent(role, message)]; } /// @@ -60,7 +99,7 @@ public ChatHistory(IEnumerable messages) } /// Gets the number of messages in the history. - public int Count => this._messages.Count; + public virtual int Count => this._messages.Count; /// /// Role of the message author @@ -118,29 +157,32 @@ public void AddDeveloperMessage(string content) => /// Adds a message to the history. /// The message to add. /// is null. 
- public void Add(ChatMessageContent item) + public virtual void Add(ChatMessageContent item) { Verify.NotNull(item); this._messages.Add(item); + this._overrideAdd?.Invoke(item); } /// Adds the messages to the history. /// The collection whose messages should be added to the history. /// is null. - public void AddRange(IEnumerable items) + public virtual void AddRange(IEnumerable items) { Verify.NotNull(items); this._messages.AddRange(items); + this._overrideAddRange?.Invoke(items); } /// Inserts a message into the history at the specified index. /// The index at which the item should be inserted. /// The message to insert. /// is null. - public void Insert(int index, ChatMessageContent item) + public virtual void Insert(int index, ChatMessageContent item) { Verify.NotNull(item); this._messages.Insert(index, item); + this._overrideInsert?.Invoke(index, item); } /// @@ -151,17 +193,22 @@ public void Insert(int index, ChatMessageContent item) /// is null. /// The number of messages in the history is greater than the available space from to the end of . /// is less than 0. - public void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex); + public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) + => this._messages.CopyTo(array, arrayIndex); /// Removes all messages from the history. - public void Clear() => this._messages.Clear(); + public virtual void Clear() + { + this._messages.Clear(); + this._overrideClear?.Invoke(); + } /// Gets or sets the message at the specified index in the history. /// The index of the message to get or set. /// The message at the specified index. /// is null. /// The was not valid for this history. - public ChatMessageContent this[int index] + public virtual ChatMessageContent this[int index] { get => this._messages[index]; set @@ -175,7 +222,7 @@ public ChatMessageContent this[int index] /// The message to locate. /// true if the message is found in the history; otherwise, false. 
/// is null. - public bool Contains(ChatMessageContent item) + public virtual bool Contains(ChatMessageContent item) { Verify.NotNull(item); return this._messages.Contains(item); @@ -185,7 +232,7 @@ public bool Contains(ChatMessageContent item) /// The message to locate. /// The index of the first found occurrence of the specified message; -1 if the message could not be found. /// is null. - public int IndexOf(ChatMessageContent item) + public virtual int IndexOf(ChatMessageContent item) { Verify.NotNull(item); return this._messages.IndexOf(item); @@ -194,16 +241,22 @@ public int IndexOf(ChatMessageContent item) /// Removes the message at the specified index from the history. /// The index of the message to remove. /// The was not valid for this history. - public void RemoveAt(int index) => this._messages.RemoveAt(index); + public virtual void RemoveAt(int index) + { + this._messages.RemoveAt(index); + this._overrideRemoveAt?.Invoke(index); + } /// Removes the first occurrence of the specified message from the history. /// The message to remove from the history. /// true if the item was successfully removed; false if it wasn't located in the history. /// is null. - public bool Remove(ChatMessageContent item) + public virtual bool Remove(ChatMessageContent item) { Verify.NotNull(item); - return this._messages.Remove(item); + var result = this._messages.Remove(item); + this._overrideRemove?.Invoke(item); + return result; } /// @@ -214,9 +267,10 @@ public bool Remove(ChatMessageContent item) /// is less than 0. /// is less than 0. /// and do not denote a valid range of messages. 
- public void RemoveRange(int index, int count) + public virtual void RemoveRange(int index, int count) { this._messages.RemoveRange(index, count); + this._overrideRemoveRange?.Invoke(index, count); } /// diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs index 7095c2dfacc4..381e073a1446 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -78,4 +79,68 @@ public static async Task ReduceAsync(this ChatHistory chatHistory, return chatHistory; } + + /// Converts a to a list. + /// The chat history to convert. + /// A list of objects. + internal static List ToChatMessageList(this ChatHistory chatHistory) + => chatHistory.Select(m => m.ToChatMessage()).ToList(); + + internal static void SetChatMessageHandlers(this ChatHistory chatHistory, IList messages) + { + chatHistory.SetOverrides(Add, Remove, Clear, Insert, RemoveAt, RemoveRange, AddRange); + + void Add(ChatMessageContent item) + { + messages.Add(item.ToChatMessage()); + } + + void Clear() + { + messages.Clear(); + } + + bool Remove(ChatMessageContent item) + { + var index = chatHistory.IndexOf(item); + + if (index < 0) + { + return false; + } + + messages.RemoveAt(index); + + return true; + } + + void Insert(int index, ChatMessageContent item) + { + messages.Insert(index, item.ToChatMessage()); + } + + void RemoveAt(int index) + { + messages.RemoveAt(index); + } + + void RemoveRange(int index, int count) + { + if (messages is List messageList) + { + messageList.RemoveRange(index, count); + return; + } + + foreach (var chatMessage in messages.Skip(index).Take(count)) + { + 
messages.Remove(chatMessage); + } + } + + void AddRange(IEnumerable items) + { + messages.AddRange(items.Select(i => i.ToChatMessage())); + } + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs new file mode 100644 index 000000000000..d92fb15d80be --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +/// Extensions methods for . +public static class PromptExecutionSettingsExtensions +{ + /// Converts a pair of and to a . + public static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? kernel) + { + if (settings is null) + { + return null; + } + + if (settings.GetType() != typeof(PromptExecutionSettings)) + { + // If the settings are of a derived type, roundtrip through JSON to the base type in order to try + // to get the derived strongly-typed properties to show up in the loosely-typed ExtensionData dictionary. + // This has the unfortunate effect of making all the ExtensionData values into JsonElements, so we lose + // some type fidelity. (As an alternative, we could introduce new interfaces that could be queried for + // in this method and implemented by the derived settings types to control how they're converted to + // ChatOptions.) 
+ settings = JsonSerializer.Deserialize( + JsonSerializer.Serialize(settings, AbstractionsJsonContext.GetTypeInfo(settings.GetType(), null)), + AbstractionsJsonContext.Default.PromptExecutionSettings); + } + + ChatOptions options = new() + { + ModelId = settings!.ModelId + }; + + if (settings!.ExtensionData is IDictionary extensionData) + { + foreach (var entry in extensionData) + { + if (entry.Key.Equals("temperature", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float temperature)) + { + options.Temperature = temperature; + } + else if (entry.Key.Equals("top_p", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float topP)) + { + options.TopP = topP; + } + else if (entry.Key.Equals("top_k", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out int topK)) + { + options.TopK = topK; + } + else if (entry.Key.Equals("seed", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out long seed)) + { + options.Seed = seed; + } + else if (entry.Key.Equals("max_tokens", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out int maxTokens)) + { + options.MaxOutputTokens = maxTokens; + } + else if (entry.Key.Equals("frequency_penalty", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float frequencyPenalty)) + { + options.FrequencyPenalty = frequencyPenalty; + } + else if (entry.Key.Equals("presence_penalty", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out float presencePenalty)) + { + options.PresencePenalty = presencePenalty; + } + else if (entry.Key.Equals("stop_sequences", StringComparison.OrdinalIgnoreCase) && + TryConvert(entry.Value, out IList? stopSequences)) + { + options.StopSequences = stopSequences; + } + else if (entry.Key.Equals("response_format", StringComparison.OrdinalIgnoreCase) && + entry.Value is { } responseFormat) + { + if (TryConvert(responseFormat, out string? 
responseFormatString)) + { + options.ResponseFormat = responseFormatString switch + { + "text" => ChatResponseFormat.Text, + "json_object" => ChatResponseFormat.Json, + _ => null, + }; + } + else + { + options.ResponseFormat = responseFormat is JsonElement e ? ChatResponseFormat.ForJsonSchema(e) : null; + } + } + else + { + // Roundtripping a derived PromptExecutionSettings through the base type will have put all the + // object values in AdditionalProperties into JsonElements. Convert them back where possible. + object? value = entry.Value; + if (value is JsonElement jsonElement) + { + value = jsonElement.ValueKind switch + { + JsonValueKind.String => jsonElement.GetString(), + JsonValueKind.Number => jsonElement.GetDouble(), // not perfect, but a reasonable heuristic + JsonValueKind.True => true, + JsonValueKind.False => false, + JsonValueKind.Null => null, + _ => value, + }; + + if (jsonElement.ValueKind == JsonValueKind.Array) + { + var enumerator = jsonElement.EnumerateArray(); + + var enumeratorType = enumerator.MoveNext() ? enumerator.Current.ValueKind : JsonValueKind.Null; + + switch (enumeratorType) + { + case JsonValueKind.String: + value = enumerator.Select(e => e.GetString()); + break; + case JsonValueKind.Number: + value = enumerator.Select(e => e.GetDouble()); + break; + case JsonValueKind.True or JsonValueKind.False: + value = enumerator.Select(e => e.ValueKind == JsonValueKind.True); + break; + } + } + } + + (options.AdditionalProperties ??= [])[entry.Key] = value; + } + } + } + + if (settings.FunctionChoiceBehavior?.GetConfiguration(new([]) { Kernel = kernel }).Functions is { Count: > 0 } functions) + { + options.ToolMode = settings.FunctionChoiceBehavior is RequiredFunctionChoiceBehavior ? ChatToolMode.RequireAny : ChatToolMode.Auto; + options.Tools = []; + foreach (var function in functions) + { + // Clone the function to ensure it works running as a AITool lower-level abstraction for the specified kernel. 
+ var functionClone = function.WithKernel(kernel); + options.Tools.Add(functionClone); + } + } + + // Enables usage of AutoFunctionInvocationFilters + return kernel is null + ? options + : new KernelChatOptions(kernel, options, settings: settings); + + // Be a little lenient on the types of the values used in the extension data, + // e.g. allow doubles even when requesting floats. + static bool TryConvert(object? value, [NotNullWhen(true)] out T? result) + { + if (value is not null) + { + // If the value is a T, use it. + if (value is T typedValue) + { + result = typedValue; + return true; + } + + if (value is JsonElement json) + { + // If the value is JsonElement, it likely resulted from JSON serializing as object. + // Try to deserialize it as a T. This currently will only be successful either when + // reflection-based serialization is enabled or T is one of the types special-cased + // in the AbstractionsJsonContext. For other cases with NativeAOT, we would need to + // have a JsonSerializationOptions with the relevant type information. + if (AbstractionsJsonContext.TryGetTypeInfo(typeof(T), firstOptions: null, out JsonTypeInfo? jti)) + { + try + { + result = (T)json.Deserialize(jti)!; + return true; + } + catch (Exception e) when (e is ArgumentException or JsonException or NotSupportedException or InvalidOperationException) + { + } + } + } + else + { + // Otherwise, try to convert it to a T using Convert, in particular to handle conversions between numeric primitive types. 
+ try + { + result = (T)Convert.ChangeType(value, typeof(T), CultureInfo.InvariantCulture); + return true; + } + catch (Exception e) when (e is ArgumentException or FormatException or InvalidCastException or OverflowException) + { + } + } + } + + result = default; + return false; + } + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AbstractionsJsonContext.cs b/dotnet/src/SemanticKernel.Abstractions/AbstractionsJsonContext.cs index 7f7b37c5d754..ae98b5b55051 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AbstractionsJsonContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AbstractionsJsonContext.cs @@ -68,7 +68,9 @@ private static JsonSerializerOptions CreateDefaultToolJsonOptions() // and we want to be flexible in terms of what can be put into the various collections in the object model. // Otherwise, use the source-generated options to enable trimming and Native AOT. - if (JsonSerializer.IsReflectionEnabledByDefault) + if (JsonSerializer.IsReflectionEnabledByDefault + // This is a workaround for the fact that the default options are not available when running in Native AOT. + || Default is null) { // Keep in sync with the JsonSourceGenerationOptions attribute on JsonContext above. 
JsonSerializerOptions options = new(JsonSerializerDefaults.Web) diff --git a/dotnet/src/SemanticKernel.Abstractions/CompatibilitySuppressions.xml b/dotnet/src/SemanticKernel.Abstractions/CompatibilitySuppressions.xml new file mode 100644 index 000000000000..da61649a30bd --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/CompatibilitySuppressions.xml @@ -0,0 +1,18 @@ + + + + + CP0002 + M:Microsoft.SemanticKernel.ChatCompletion.ChatCompletionServiceExtensions.AsChatCompletionService(Microsoft.Extensions.AI.IChatClient,System.IServiceProvider) + lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll + lib/net8.0/Microsoft.SemanticKernel.Abstractions.dll + true + + + CP0002 + M:Microsoft.SemanticKernel.ChatCompletion.ChatCompletionServiceExtensions.AsChatCompletionService(Microsoft.Extensions.AI.IChatClient,System.IServiceProvider) + lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll + lib/netstandard2.0/Microsoft.SemanticKernel.Abstractions.dll + true + + \ No newline at end of file diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs new file mode 100644 index 000000000000..276d500ce787 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/ChatMessageContentExtensions.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +internal static class ChatMessageContentExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. + internal static ChatMessage ToChatMessage(this ChatMessageContent content) + { + ChatMessage message = new() + { + AdditionalProperties = content.Metadata is not null ? new(content.Metadata) : null, + AuthorName = content.AuthorName, + RawRepresentation = content.InnerContent, + Role = content.Role.Label is string label ? 
new ChatRole(label) : ChatRole.User, + }; + + foreach (var item in content.Items) + { + AIContent? aiContent = null; + switch (item) + { + case Microsoft.SemanticKernel.TextContent tc: + aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); + break; + + case Microsoft.SemanticKernel.ImageContent ic: + aiContent = + ic.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ic.DataUri, ic.MimeType) : + ic.Uri is not null ? new Microsoft.Extensions.AI.UriContent(ic.Uri, ic.MimeType ?? "image/*") : + null; + break; + + case Microsoft.SemanticKernel.AudioContent ac: + aiContent = + ac.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(ac.DataUri, ac.MimeType) : + ac.Uri is not null ? new Microsoft.Extensions.AI.UriContent(ac.Uri, ac.MimeType ?? "audio/*") : + null; + break; + + case Microsoft.SemanticKernel.BinaryContent bc: + aiContent = + bc.DataUri is not null ? new Microsoft.Extensions.AI.DataContent(bc.DataUri, bc.MimeType) : + bc.Uri is not null ? new Microsoft.Extensions.AI.UriContent(bc.Uri, bc.MimeType ?? "application/octet-stream") : + null; + break; + + case Microsoft.SemanticKernel.FunctionCallContent fcc: + aiContent = new Microsoft.Extensions.AI.FunctionCallContent(fcc.Id ?? string.Empty, fcc.FunctionName, fcc.Arguments); + break; + + case Microsoft.SemanticKernel.FunctionResultContent frc: + aiContent = new Microsoft.Extensions.AI.FunctionResultContent(frc.CallId ?? string.Empty, frc.Result); + break; + } + + if (aiContent is not null) + { + aiContent.RawRepresentation = item.InnerContent; + aiContent.AdditionalProperties = item.Metadata is not null ? 
new(item.Metadata) : null; + + message.Contents.Add(aiContent); + } + } + + return message; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs new file mode 100644 index 000000000000..78e2f8445a78 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Contents/StreamingChatMessageContentExtensions.cs @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +/// Provides extension methods for . +internal static class StreamingChatMessageContentExtensions +{ + /// Converts a to a . + /// This conversion should not be necessary once SK eventually adopts the shared content types. + internal static ChatResponseUpdate ToChatResponseUpdate(this StreamingChatMessageContent content) + { + ChatResponseUpdate update = new() + { + AdditionalProperties = content.Metadata is not null ? new AdditionalPropertiesDictionary(content.Metadata) : null, + AuthorName = content.AuthorName, + ModelId = content.ModelId, + RawRepresentation = content.InnerContent, + Role = content.Role is not null ? new ChatRole(content.Role.Value.Label) : null, + }; + + foreach (var item in content.Items) + { + AIContent? aiContent = null; + switch (item) + { + case Microsoft.SemanticKernel.StreamingTextContent tc: + aiContent = new Microsoft.Extensions.AI.TextContent(tc.Text); + break; + + case Microsoft.SemanticKernel.StreamingFunctionCallUpdateContent fcc: + aiContent = new Microsoft.Extensions.AI.FunctionCallContent( + fcc.CallId ?? string.Empty, + fcc.Name ?? string.Empty, + !string.IsNullOrWhiteSpace(fcc.Arguments) ? JsonSerializer.Deserialize>(fcc.Arguments, AbstractionsJsonContext.Default.IDictionaryStringObject!) 
: null); + break; + } + + if (aiContent is not null) + { + aiContent.RawRepresentation = content; + + update.Contents.Add(aiContent); + } + } + + return update; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index 0f18be8df8e0..53551fe67a3c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -1,6 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Threading; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; @@ -8,8 +13,30 @@ namespace Microsoft.SemanticKernel; /// /// Class with data related to automatic function invocation. /// -public class AutoFunctionInvocationContext +public class AutoFunctionInvocationContext : Microsoft.Extensions.AI.FunctionInvocationContext { + private ChatHistory? _chatHistory; + + /// + /// Initializes a new instance of the class from an existing . 
+ /// + internal AutoFunctionInvocationContext(KernelChatOptions autoInvocationChatOptions, AIFunction aiFunction) + { + Verify.NotNull(autoInvocationChatOptions); + Verify.NotNull(aiFunction); + if (aiFunction is not KernelFunction kernelFunction) + { + throw new InvalidOperationException($"The function must be of type {nameof(KernelFunction)}."); + } + Verify.NotNull(autoInvocationChatOptions.Kernel); + Verify.NotNull(autoInvocationChatOptions.ChatMessageContent); + + this.Options = autoInvocationChatOptions; + this.ExecutionSettings = autoInvocationChatOptions.ExecutionSettings; + this.AIFunction = aiFunction; + this.Result = new FunctionResult(kernelFunction) { Culture = autoInvocationChatOptions.Kernel.Culture }; + } + /// /// Initializes a new instance of the class. /// @@ -31,11 +58,16 @@ public AutoFunctionInvocationContext( Verify.NotNull(chatHistory); Verify.NotNull(chatMessageContent); - this.Kernel = kernel; - this.Function = function; + this.Options = new KernelChatOptions(kernel) + { + ChatMessageContent = chatMessageContent, + }; + + this._chatHistory = chatHistory; + this.Messages = chatHistory.ToChatMessageList(); + chatHistory.SetChatMessageHandlers(this.Messages); + base.Function = function; this.Result = result; - this.ChatHistory = chatHistory; - this.ChatMessageContent = chatMessageContent; } /// @@ -45,59 +77,107 @@ public AutoFunctionInvocationContext( public CancellationToken CancellationToken { get; init; } /// - /// Boolean flag which indicates whether a filter is invoked within streaming or non-streaming mode. + /// Gets the specialized version of associated with the operation. /// - public bool IsStreaming { get; init; } + /// + /// Due to a clash with the as a type, this property hides + /// it to not break existing code that relies on the as a type. + /// + /// Attempting to access the property when the arguments is not a class. + public new KernelArguments? 
Arguments + { + get + { + if (base.Arguments is KernelArguments kernelArguments) + { + return kernelArguments; + } - /// - /// Gets the arguments associated with the operation. - /// - public KernelArguments? Arguments { get; init; } + throw new InvalidOperationException($"The arguments provided in the initialization must be of type {nameof(KernelArguments)}."); + } + init => base.Arguments = value ?? new(); + } /// /// Request sequence index of automatic function invocation process. Starts from 0. /// - public int RequestSequenceIndex { get; init; } + public int RequestSequenceIndex + { + get => this.Iteration; + init => this.Iteration = value; + } /// /// Function sequence index. Starts from 0. /// - public int FunctionSequenceIndex { get; init; } - - /// - /// Number of functions that will be invoked during auto function invocation request. - /// - public int FunctionCount { get; init; } + public int FunctionSequenceIndex + { + get => this.FunctionCallIndex; + init => this.FunctionCallIndex = value; + } /// /// The ID of the tool call. /// - public string? ToolCallId { get; init; } + public string? ToolCallId + { + get => this.CallContent.CallId; + init + { + this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent( + callId: value ?? string.Empty, + name: this.CallContent.Name, + arguments: this.CallContent.Arguments); + } + } /// /// The chat message content associated with automatic function invocation. /// - public ChatMessageContent ChatMessageContent { get; } + public ChatMessageContent ChatMessageContent => (this.Options as KernelChatOptions)!.ChatMessageContent!; /// /// The execution settings associated with the operation. /// - public PromptExecutionSettings? ExecutionSettings { get; init; } + public PromptExecutionSettings? 
ExecutionSettings + { + get => ((KernelChatOptions)this.Options!).ExecutionSettings; + init + { + this.Options ??= new KernelChatOptions(this.Kernel); + ((KernelChatOptions)this.Options!).ExecutionSettings = value; + } + } /// /// Gets the associated with automatic function invocation. /// - public ChatHistory ChatHistory { get; } + public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this.Messages); /// /// Gets the with which this filter is associated. /// - public KernelFunction Function { get; } + /// + /// Due to a clash with the as a type, this property hides + /// it to not break existing code that relies on the as a type. + /// + public new KernelFunction Function + { + get + { + if (this.AIFunction is KernelFunction kf) + { + return kf; + } + + throw new InvalidOperationException($"The function provided in the initialization must be of type {nameof(KernelFunction)}."); + } + } /// /// Gets the containing services, plugins, and other state for use throughout the operation. /// - public Kernel Kernel { get; } + public Kernel Kernel => ((KernelChatOptions)this.Options!).Kernel!; /// /// Gets or sets the result of the function's invocation. @@ -105,18 +185,131 @@ public AutoFunctionInvocationContext( public FunctionResult Result { get; set; } /// - /// Gets or sets a value indicating whether the operation associated with the filter should be terminated. - /// - /// By default, this value is , which means all functions will be invoked. - /// If set to , the behavior depends on how functions are invoked: - /// - /// - If functions are invoked sequentially (the default behavior), the remaining functions will not be invoked, - /// and the last request to the LLM will not be performed. - /// - /// - If functions are invoked concurrently (controlled by the option), - /// other functions will still be invoked, and the last request to the LLM will not be performed. 
- /// - /// In both cases, the automatic function invocation process will be terminated, and the result of the last executed function will be returned to the caller. + /// Gets or sets the with which this filter is associated. + /// + internal AIFunction AIFunction + { + get => base.Function; + set => base.Function = value; + } + + private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction) + { + // Compares the schemas, should be similar. + return string.Equals( + kernelFunction.JsonSchema.ToString(), + aiFunction.JsonSchema.ToString(), + StringComparison.OrdinalIgnoreCase); + + // TODO: Later can be improved by comparing the underlying methods. + // return kernelFunction.UnderlyingMethod == aiFunction.UnderlyingMethod; + } + + /// + /// Mutable IEnumerable of chat message as chat history. /// - public bool Terminate { get; set; } + private class ChatMessageHistory : ChatHistory, IEnumerable + { + private readonly List _messages; + + internal ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory()) + { + this._messages = new List(messages); + } + + public override void Add(ChatMessageContent item) + { + base.Add(item); + this._messages.Add(item.ToChatMessage()); + } + + public override void Clear() + { + base.Clear(); + this._messages.Clear(); + } + + public override bool Remove(ChatMessageContent item) + { + var index = base.IndexOf(item); + + if (index < 0) + { + return false; + } + + this._messages.RemoveAt(index); + base.RemoveAt(index); + + return true; + } + + public override void Insert(int index, ChatMessageContent item) + { + base.Insert(index, item); + this._messages.Insert(index, item.ToChatMessage()); + } + + public override void RemoveAt(int index) + { + this._messages.RemoveAt(index); + base.RemoveAt(index); + } + + public override ChatMessageContent this[int index] + { + get => this._messages[index].ToChatMessageContent(); + set + { + this._messages[index] = value.ToChatMessage(); + base[index] = value; + 
} + } + + public override void RemoveRange(int index, int count) + { + this._messages.RemoveRange(index, count); + base.RemoveRange(index, count); + } + + public override void CopyTo(ChatMessageContent[] array, int arrayIndex) + { + for (int i = 0; i < this._messages.Count; i++) + { + array[arrayIndex + i] = this._messages[i].ToChatMessageContent(); + } + } + + public override bool Contains(ChatMessageContent item) => base.Contains(item); + + public override int IndexOf(ChatMessageContent item) => base.IndexOf(item); + + public override void AddRange(IEnumerable items) + { + base.AddRange(items); + this._messages.AddRange(items.Select(i => i.ToChatMessage())); + } + + public override int Count => this._messages.Count; + + // Explicit implementation of IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var message in this._messages) + { + yield return message.ToChatMessageContent(); // Convert and yield each item + } + } + + // Explicit implementation of non-generic IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() + => ((IEnumerable)this).GetEnumerator(); + } + + /// Destructor to clear the chat history overrides. + ~AutoFunctionInvocationContext() + { + // The moment this class is destroyed, we need to clear the update message overrides + this._chatHistory?.ClearOverrides(); + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/FullyQualifiedAIFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/FullyQualifiedAIFunction.cs new file mode 100644 index 000000000000..c7f1b336dd97 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/FullyQualifiedAIFunction.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +/// +/// Represents a kernel function that provides the plugin name as part of the original function name. 
+/// +public abstract class FullyQualifiedAIFunction : AIFunction +{ + /// + /// Initializes a new instance of the class. + /// + /// The metadata describing the function. + internal FullyQualifiedAIFunction(KernelFunctionMetadata metadata) + { + this.Metadata = metadata; + } + + /// + /// Gets the metadata describing the function. + /// + /// An instance of describing the function + public KernelFunctionMetadata Metadata { get; init; } + + /// + /// Gets the name of the function. + /// + /// + /// The fully qualified name (including the plugin name) is used anywhere the function needs to be identified, such as in plans describing what functions + /// should be invoked when, or as part of lookups in a plugin's function collection. Function names are generally + /// handled in an ordinal case-insensitive manner. + /// + public override string Name + => !string.IsNullOrWhiteSpace(this.Metadata.PluginName) + ? $"{this.Metadata.PluginName}_{this.Metadata.Name}" + : this.Metadata.Name; +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs index 0902a4f80c98..4454c2c0b493 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/FunctionResult.cs @@ -3,6 +3,9 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Linq; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; @@ -101,6 +104,105 @@ public FunctionResult(FunctionResult result, object? value = null) { return innerContent; } + + // Attempting to use the new Microsoft.Extensions.AI Chat types will trigger automatic conversion of SK chat contents. 
+ + // ChatMessageContent as ChatMessage + if (typeof(T) == typeof(ChatMessage) + && content is ChatMessageContent chatMessageContent) + { + return (T?)(object)chatMessageContent.ToChatMessage(); + } + + // ChatMessageContent as ChatResponse + if (typeof(T) == typeof(ChatResponse) + && content is ChatMessageContent singleChoiceMessageContent) + { + return (T?)(object)new Microsoft.Extensions.AI.ChatResponse(singleChoiceMessageContent.ToChatMessage()); + } + } + + if (this.Value is IReadOnlyList messageContentList) + { + if (messageContentList.Count == 0) + { + throw new InvalidCastException($"Cannot cast a response with no choices to {typeof(T)}"); + } + + var firstMessage = messageContentList[0]; + if (typeof(T) == typeof(ChatResponse)) + { + // Ignore multiple choices when converting to Microsoft.Extensions.AI.ChatResponse + return (T)(object)new ChatResponse(firstMessage.ToChatMessage()); + } + + if (typeof(T) == typeof(ChatMessage)) + { + return (T)(object)firstMessage.ToChatMessage(); + } + } + + if (this.Value is Microsoft.Extensions.AI.ChatResponse chatResponse) + { + // If no choices are present, return default + if (chatResponse.Messages.Count == 0) + { + throw new InvalidCastException($"Cannot cast a response with no messages to {typeof(T)}"); + } + + var chatMessage = chatResponse.Messages.Last(); + if (typeof(T) == typeof(string)) + { + return (T?)(object?)chatMessage.ToString(); + } + + // ChatMessage from a ChatResponse + if (typeof(T) == typeof(ChatMessage)) + { + return (T?)(object)chatMessage; + } + + if (typeof(Microsoft.Extensions.AI.AIContent).IsAssignableFrom(typeof(T))) + { + // Return the first matching content type of a message if any + var updateContent = chatMessage.Contents.FirstOrDefault(c => c is T); + if (updateContent is not null) + { + return (T)(object)updateContent; + } + } + + if (chatMessage.Contents is T contentsList) + { + return contentsList; + } + + if (chatResponse.RawRepresentation is T rawResponseRepresentation) + { + 
return rawResponseRepresentation; + } + + if (chatMessage.RawRepresentation is T rawMessageRepresentation) + { + return rawMessageRepresentation; + } + + if (typeof(Microsoft.Extensions.AI.AIContent).IsAssignableFrom(typeof(T))) + { + // Return the first matching content type of a message if any + var updateContent = chatMessage.Contents.FirstOrDefault(c => c is T); + if (updateContent is not null) + { + return (T)(object)updateContent; + } + } + + // Avoid breaking changes this transformation will be dropped once we migrate fully to Microsoft.Extensions.AI abstractions. + // This is also necessary to don't break existing code using KernelContents when using IChatClient connectors. + if (typeof(KernelContent).IsAssignableFrom(typeof(T))) + { + return (T)(object)chatMessage.ToChatMessageContent(); + } } throw new InvalidCastException($"Cannot cast {this.Value.GetType()} to {typeof(T)}"); diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs index eda736b3f583..419a12039049 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelArguments.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections; using System.Collections.Generic; using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; #pragma warning disable CA1710 // Identifiers should have correct suffix @@ -17,10 +17,8 @@ namespace Microsoft.SemanticKernel; /// A is a dictionary of argument names and values. It also carries a /// , accessible via the property. /// -public sealed class KernelArguments : IDictionary, IReadOnlyDictionary +public sealed class KernelArguments : AIFunctionArguments { - /// Dictionary of name/values for all the arguments in the instance. - private readonly Dictionary _arguments; private IReadOnlyDictionary? 
_executionSettings; /// @@ -28,8 +26,8 @@ public sealed class KernelArguments : IDictionary, IReadOnlyDic /// [JsonConstructor] public KernelArguments() + : base(StringComparer.OrdinalIgnoreCase) { - this._arguments = new(StringComparer.OrdinalIgnoreCase); } /// @@ -37,7 +35,7 @@ public KernelArguments() /// /// The prompt execution settings. public KernelArguments(PromptExecutionSettings? executionSettings) - : this(executionSettings is null ? null : [executionSettings]) + : this(executionSettings: executionSettings is null ? null : [executionSettings]) { } @@ -46,8 +44,8 @@ public KernelArguments(PromptExecutionSettings? executionSettings) /// /// The prompt execution settings. public KernelArguments(IEnumerable? executionSettings) + : base(StringComparer.OrdinalIgnoreCase) { - this._arguments = new(StringComparer.OrdinalIgnoreCase); if (executionSettings is not null) { var newExecutionSettings = new Dictionary(); @@ -80,10 +78,8 @@ public KernelArguments(IEnumerable? executionSettings) /// Otherwise, if the source is a , its are used. /// public KernelArguments(IDictionary source, Dictionary? executionSettings = null) + : base(source, StringComparer.OrdinalIgnoreCase) { - Verify.NotNull(source); - - this._arguments = new(source, StringComparer.OrdinalIgnoreCase); this.ExecutionSettings = executionSettings ?? (source as KernelArguments)?.ExecutionSettings; } @@ -115,37 +111,6 @@ public IReadOnlyDictionary? ExecutionSettings } } - /// - /// Gets the number of arguments contained in the . - /// - public int Count => this._arguments.Count; - - /// Adds the specified argument name and value to the . - /// The name of the argument to add. - /// The value of the argument to add. - /// is null. - /// An argument with the same name already exists in the . - public void Add(string name, object? value) - { - Verify.NotNull(name); - this._arguments.Add(name, value); - } - - /// Removes the argument value with the specified name from the . 
- /// The name of the argument value to remove. - /// is null. - public bool Remove(string name) - { - Verify.NotNull(name); - return this._arguments.Remove(name); - } - - /// Removes all arguments names and values from the . - /// - /// This does not affect the property. To clear it as well, set it to null. - /// - public void Clear() => this._arguments.Clear(); - /// Determines whether the contains an argument with the specified name. /// The name of the argument to locate. /// true if the arguments contains an argument with the specified named; otherwise, false. @@ -153,103 +118,9 @@ public bool Remove(string name) public bool ContainsName(string name) { Verify.NotNull(name); - return this._arguments.ContainsKey(name); - } - - /// Gets the value associated with the specified argument name. - /// The name of the argument value to get. - /// - /// When this method returns, contains the value associated with the specified name, - /// if the name is found; otherwise, null. - /// - /// true if the arguments contains an argument with the specified name; otherwise, false. - /// is null. - public bool TryGetValue(string name, out object? value) - { - Verify.NotNull(name); - return this._arguments.TryGetValue(name, out value); - } - - /// Gets or sets the value associated with the specified argument name. - /// The name of the argument value to get or set. - /// is null. - public object? this[string name] - { - get - { - Verify.NotNull(name); - return this._arguments[name]; - } - set - { - Verify.NotNull(name); - this._arguments[name] = value; - } - } - - /// Gets an of all of the arguments' names. - public ICollection Names => this._arguments.Keys; - - /// Gets an of all of the arguments' values. 
- public ICollection Values => this._arguments.Values; - - #region Interface implementations - /// - ICollection IDictionary.Keys => this._arguments.Keys; - - /// - IEnumerable IReadOnlyDictionary.Keys => this._arguments.Keys; - - /// - IEnumerable IReadOnlyDictionary.Values => this._arguments.Values; - - /// - bool ICollection>.IsReadOnly => false; - - /// - object? IReadOnlyDictionary.this[string key] => this._arguments[key]; - - /// - object? IDictionary.this[string key] - { - get => this._arguments[key]; - set => this._arguments[key] = value; + return base.ContainsKey(name); } - /// - void IDictionary.Add(string key, object? value) => this._arguments.Add(key, value); - - /// - bool IDictionary.ContainsKey(string key) => this._arguments.ContainsKey(key); - - /// - bool IDictionary.Remove(string key) => this._arguments.Remove(key); - - /// - bool IDictionary.TryGetValue(string key, out object? value) => this._arguments.TryGetValue(key, out value); - - /// - void ICollection>.Add(KeyValuePair item) => this._arguments.Add(item.Key, item.Value); - - /// - bool ICollection>.Contains(KeyValuePair item) => ((ICollection>)this._arguments).Contains(item); - - /// - void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)this._arguments).CopyTo(array, arrayIndex); - - /// - bool ICollection>.Remove(KeyValuePair item) => this._arguments.Remove(item.Key); - - /// - IEnumerator> IEnumerable>.GetEnumerator() => this._arguments.GetEnumerator(); - - /// - IEnumerator IEnumerable.GetEnumerator() => this._arguments.GetEnumerator(); - - /// - bool IReadOnlyDictionary.ContainsKey(string key) => this._arguments.ContainsKey(key); - - /// - bool IReadOnlyDictionary.TryGetValue(string key, out object? value) => this._arguments.TryGetValue(key, out value); - #endregion + /// Gets an of all of the arguments names. 
+ public ICollection Names => this.Keys; } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index 4188bc6b6994..48fc1bb8b460 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -19,13 +19,17 @@ using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Functions; +#pragma warning disable CS8600 // Converting null literal or possible null value to non-nullable type. + namespace Microsoft.SemanticKernel; /// /// Represents a function that can be invoked as part of a Semantic Kernel workload. /// -public abstract class KernelFunction +public abstract class KernelFunction : FullyQualifiedAIFunction { + private static readonly JsonElement s_defaultSchema = JsonDocument.Parse("{}").RootElement; + /// The measurement tag name for the function name. private protected const string MeasurementFunctionTagName = "semantic_kernel.function.name"; @@ -38,6 +42,18 @@ public abstract class KernelFunction /// for function-related metrics. private protected static readonly Meter s_meter = new("Microsoft.SemanticKernel"); + /// The to use for serialization and deserialization of various aspects of the function. + private readonly JsonSerializerOptions? _jsonSerializerOptions; + + /// The underlying method, if this function was created from a method. +#pragma warning disable CA1051 + protected MethodInfo? _underlyingMethod; +#pragma warning restore CA1051 + + /// The instance that will be prioritized when invoking without a provided argument. + /// This will be normally used when the function is invoked using the interface. + internal Kernel? Kernel { get; set; } + /// to record function invocation duration. 
private static readonly Histogram s_invocationDuration = s_meter.CreateHistogram( name: "semantic_kernel.function.invocation.duration", @@ -54,9 +70,6 @@ public abstract class KernelFunction unit: "s", description: "Measures the duration of a function's streaming execution"); - /// The to use for serialization and deserialization of various aspects of the function. - protected JsonSerializerOptions? JsonSerializerOptions { get; set; } - /// /// Gets the name of the function. /// @@ -65,7 +78,7 @@ public abstract class KernelFunction /// should be invoked when, or as part of lookups in a plugin's function collection. Function names are generally /// handled in an ordinal case-insensitive manner. /// - public string Name => this.Metadata.Name; + public virtual new string Name => this.Metadata.Name; /// /// Gets the name of the plugin this function was added to. @@ -74,7 +87,7 @@ public abstract class KernelFunction /// The plugin name will be null if the function has not been added to a plugin. /// When a function is added to a plugin it will be cloned and the plugin name will be set. /// - public string? PluginName => this.Metadata.PluginName; + public virtual string? PluginName => this.Metadata.PluginName; /// /// Gets a description of the function. @@ -83,13 +96,7 @@ public abstract class KernelFunction /// The description may be supplied to a model in order to elaborate on the function's purpose, /// in case it may be beneficial for the model to recommend invoking the function. /// - public string Description => this.Metadata.Description; - - /// - /// Gets the metadata describing the function. - /// - /// An instance of describing the function - public KernelFunctionMetadata Metadata { get; init; } + public override string Description => this.Metadata.Description; /// /// Gets the prompt execution settings. @@ -99,20 +106,11 @@ public abstract class KernelFunction /// public IReadOnlyDictionary? 
ExecutionSettings { get; } - /// - /// Gets the underlying that this function might be wrapping. - /// - /// - /// Provides additional metadata on the function and its signature. Implementations not wrapping .NET methods may return null. - /// - [Experimental("SKEXP0001")] - public MethodInfo? UnderlyingMethod { get; internal init; } - /// /// Initializes a new instance of the class. /// - /// A name of the function to use as its . - /// The description of the function to use as its . + /// A name of the function to use as its . + /// The description of the function to use as its . /// The metadata describing the parameters to the function. /// The metadata describing the return parameter of the function. /// @@ -129,8 +127,8 @@ internal KernelFunction(string name, string description, IReadOnlyList /// Initializes a new instance of the class. /// - /// A name of the function to use as its . - /// The description of the function to use as its . + /// A name of the function to use as its . + /// The description of the function to use as its . /// The metadata describing the parameters to the function. /// The to use for serialization and deserialization of various aspects of the function. /// The metadata describing the return parameter of the function. @@ -146,9 +144,9 @@ internal KernelFunction(string name, string description, IReadOnlyList /// Initializes a new instance of the class. /// - /// A name of the function to use as its . + /// A name of the function to use as its . /// The name of the plugin this function instance has been added to. - /// The description of the function to use as its . + /// The description of the function to use as its . /// The metadata describing the parameters to the function. /// The metadata describing the return parameter of the function. /// @@ -159,18 +157,16 @@ internal KernelFunction(string name, string description, IReadOnlyList parameters, KernelReturnParameterMetadata? returnParameter = null, Dictionary? 
executionSettings = null, ReadOnlyDictionary? additionalMetadata = null) - { - Verify.NotNull(name); - KernelVerify.ParametersUniqueness(parameters); - - this.Metadata = new KernelFunctionMetadata(name) + : base(new KernelFunctionMetadata(Throw.IfNull(name)) { PluginName = pluginName, Description = description, - Parameters = parameters, + Parameters = KernelVerify.ParametersUniqueness(parameters), ReturnParameter = returnParameter ?? KernelReturnParameterMetadata.Empty, AdditionalProperties = additionalMetadata ?? KernelFunctionMetadata.s_emptyDictionary, - }; + }) + { + this.BuildFunctionSchema(); if (executionSettings is not null) { @@ -180,12 +176,15 @@ internal KernelFunction(string name, string? pluginName, string description, IRe } } + /// + public override JsonElement JsonSchema => this._jsonSchema; + /// /// Initializes a new instance of the class. /// - /// A name of the function to use as its . + /// A name of the function to use as its . /// The name of the plugin this function instance has been added to. - /// The description of the function to use as its . + /// The description of the function to use as its . /// The metadata describing the parameters to the function. /// The to use for serialization and deserialization of various aspects of the function. /// The metadata describing the return parameter of the function. @@ -195,19 +194,18 @@ internal KernelFunction(string name, string? pluginName, string description, IRe /// /// Properties/metadata associated with the function itself rather than its parameters and return type. internal KernelFunction(string name, string? pluginName, string description, IReadOnlyList parameters, JsonSerializerOptions jsonSerializerOptions, KernelReturnParameterMetadata? returnParameter = null, Dictionary? executionSettings = null, ReadOnlyDictionary? 
additionalMetadata = null) - { - Verify.NotNull(name); - KernelVerify.ParametersUniqueness(parameters); - Verify.NotNull(jsonSerializerOptions); - - this.Metadata = new KernelFunctionMetadata(name) + : base(new KernelFunctionMetadata(Throw.IfNull(name)) { PluginName = pluginName, Description = description, - Parameters = parameters, + Parameters = KernelVerify.ParametersUniqueness(parameters), ReturnParameter = returnParameter ?? KernelReturnParameterMetadata.Empty, AdditionalProperties = additionalMetadata ?? KernelFunctionMetadata.s_emptyDictionary, - }; + }) + { + Verify.NotNull(jsonSerializerOptions); + + this.BuildFunctionSchema(); if (executionSettings is not null) { @@ -216,9 +214,15 @@ internal KernelFunction(string name, string? pluginName, string description, IRe entry => { var clone = entry.Value.Clone(); clone.Freeze(); return clone; }); } - this.JsonSerializerOptions = jsonSerializerOptions; + this._jsonSerializerOptions = jsonSerializerOptions; } + /// + public override JsonSerializerOptions JsonSerializerOptions => this._jsonSerializerOptions ?? base.JsonSerializerOptions; + + /// + public override MethodInfo? UnderlyingMethod => this._underlyingMethod; + /// /// Invokes the . /// @@ -232,6 +236,7 @@ public async Task InvokeAsync( KernelArguments? arguments = null, CancellationToken cancellationToken = default) { + kernel ??= this.Kernel; Verify.NotNull(kernel); using var activity = s_activitySource.StartActivity(this.Name); @@ -334,6 +339,7 @@ public async IAsyncEnumerable InvokeStreamingAsync( KernelArguments? arguments = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + kernel ??= this.Kernel; Verify.NotNull(kernel); using var activity = s_activitySource.StartActivity(this.Name); @@ -419,12 +425,12 @@ public async IAsyncEnumerable InvokeStreamingAsync( /// Creates a new object that is a copy of the current instance /// but the has the plugin name set. 
/// - /// The name of the plugin this function instance will be added to. + /// The optional name of the plugin this function instance will be added to. /// /// This method should only be used to create a new instance of a when adding /// a function to a . /// - public abstract KernelFunction Clone(string pluginName); + public abstract KernelFunction Clone(string? pluginName = null); /// public override string ToString() => string.IsNullOrWhiteSpace(this.PluginName) ? @@ -443,6 +449,39 @@ protected abstract ValueTask InvokeCoreAsync( KernelArguments arguments, CancellationToken cancellationToken); + /// + /// Invokes the using the interface. + /// + /// + /// When using the interface, the will be acquired as follows, in order of priority: + /// + /// From the dictionary with the key. + /// From the . service provider. + /// From the provided in when Cloning the . + /// A new instance will be created using the same service provider in the .. + /// + /// + /// The arguments to pass to the function's invocation. + /// The to monitor for cancellation requests. + /// The result of the function's execution. + protected override async ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + Kernel kernel = (arguments.TryGetValue(AIFunctionArgumentsExtensions.KernelAIFunctionArgumentKey, out var kernelObject) && kernelObject is not null) + ? (kernelObject as Kernel)! + : arguments.Services?.GetService(typeof(Kernel)) as Kernel + ?? this.Kernel + ?? new(arguments.Services); + + var kernelArguments = new KernelArguments(arguments); + + var result = await this.InvokeCoreAsync(kernel, kernelArguments, cancellationToken).ConfigureAwait(false); + + // Serialize the result to JSON, as with AIFunctionFactory.Create AIFunctions. + return result.Value is object value ? 
+ JsonSerializer.SerializeToElement(value, AbstractionsJsonContext.GetTypeInfo(value.GetType(), this.JsonSerializerOptions)) : + null; + } + /// /// Invokes the and streams its results. /// @@ -486,6 +525,26 @@ private static void HandleException( } } + private void BuildFunctionSchema() + { + KernelFunctionSchemaModel schemaModel = new() + { + Type = "object", + Description = this.Description, + }; + + foreach (var parameter in this.Metadata.Parameters) + { + schemaModel.Properties[parameter.Name] = parameter.Schema?.RootElement ?? s_defaultSchema; + if (parameter.IsRequired) + { + (schemaModel.Required ??= []).Add(parameter.Name); + } + } + + this._jsonSchema = JsonSerializer.SerializeToElement(schemaModel, AbstractionsJsonContext.Default.KernelFunctionSchemaModel); + } + [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "The warning is shown and should be addressed at the function creation site; there is no need to show it again at the function invocation sites.")] [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "The warning is shown and should be addressed at the function creation site; there is no need to show it again at the function invocation sites.")] private void LogFunctionArguments(ILogger logger, string? pluginName, string functionName, KernelArguments arguments) @@ -520,12 +579,18 @@ private void LogFunctionResult(ILogger logger, string? pluginName, string functi /// /// An instance of that, when invoked, will in turn invoke the current . [Experimental("SKEXP0001")] + [Obsolete("Use the kernel function directly or for similar behavior use Clone(Kernel) method instead.")] public AIFunction AsAIFunction(Kernel? 
kernel = null) { return new KernelAIFunction(this, kernel); } + #region Private + + private JsonElement _jsonSchema; + /// An wrapper around a . + [Obsolete("Use the kernel function directly or for similar behavior use Clone(Kernel) method instead.")] private sealed class KernelAIFunction : AIFunction { private static readonly JsonElement s_defaultSchema = JsonDocument.Parse("{}").RootElement; @@ -542,7 +607,6 @@ public KernelAIFunction(KernelFunction kernelFunction, Kernel? kernel) this.JsonSchema = BuildFunctionSchema(kernelFunction); } - public override string Name { get; } public override JsonElement JsonSchema { get; } public override string Description => this._kernelFunction.Description; @@ -588,4 +652,6 @@ private static JsonElement BuildFunctionSchema(KernelFunction function) return JsonSerializer.SerializeToElement(schemaModel, AbstractionsJsonContext.Default.KernelFunctionSchemaModel); } } + + #endregion } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionExtensions.cs new file mode 100644 index 000000000000..383f6cc01ad3 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionExtensions.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel; + +/// Provides extension methods for . +public static class KernelFunctionExtensions +{ + /// + /// Creates a cloned for a specific . Useful when this function is used as a lower-level abstraction directly. + /// + /// + /// The provided will be used by default when none is provided using the arguments in or when a null is used when invoking method. + /// + /// The to clone with a default . + /// The to use as the default option. + /// Optional plugin name to use for the new kernel cloned function. 
+ [Experimental("SKEXP0001")] + public static KernelFunction WithKernel(this KernelFunction kernelFunction, Kernel? kernel = null, string? pluginName = null) + { + var clone = kernelFunction.Clone(pluginName ?? kernelFunction.PluginName); + clone.Kernel = kernel; + + return clone; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionNoop.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionNoop.cs index ce6ebc7eaf39..5c974eb5f7ec 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionNoop.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunctionNoop.cs @@ -26,7 +26,7 @@ internal KernelFunctionNoop(IReadOnlyDictionary } /// - public override KernelFunction Clone(string pluginName) + public override KernelFunction Clone(string? pluginName = null) { Dictionary? executionSettings = this.ExecutionSettings?.ToDictionary(kv => kv.Key, kv => kv.Value); return new KernelFunctionNoop(executionSettings); diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelParameterMetadata.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelParameterMetadata.cs index c91c6657d149..b081d41a0bae 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelParameterMetadata.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelParameterMetadata.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.SemanticKernel; @@ -122,6 +123,7 @@ public object? DefaultValue public bool IsRequired { get; init; } /// Gets the .NET type of the parameter. + [JsonIgnore] public Type? 
ParameterType { get => this._parameterType; diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs index 1b6aab3c87a3..3099075b1464 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelPlugin.cs @@ -93,17 +93,19 @@ public IList GetFunctionsMetadata() /// public abstract IEnumerator GetEnumerator(); - /// Produces an for every in this plugin. - /// - /// The instance to pass to the s when invoked as part of the 's invocation. - /// - /// An enumerable of instances, one for each in this plugin. + /// Produces a clone of every in the to be used as a lower-level abstraction. + /// + /// Once this function was cloned as a , the Name will be prefixed by the i.e: PluginName_FunctionName. + /// + /// The default to be used when the is invoked. + /// An enumerable clone of instances, for each in this plugin. [Experimental("SKEXP0001")] public IEnumerable AsAIFunctions(Kernel? kernel = null) { foreach (KernelFunction function in this) { - yield return function.AsAIFunction(kernel); + var functionClone = function.WithKernel(kernel); + yield return functionClone; } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Kernel.cs b/dotnet/src/SemanticKernel.Abstractions/Kernel.cs index 99a335e15656..99425171947c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Kernel.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Kernel.cs @@ -6,6 +6,7 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Linq; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -163,6 +164,7 @@ public Kernel Clone() => /// and any functions invoked within the context can consult this property for use in /// operations like formatting and parsing. 
/// + [JsonIgnore] [AllowNull] public CultureInfo Culture { @@ -177,6 +179,7 @@ public CultureInfo Culture /// This returns any in . If there is /// none, it returns an that won't perform any logging. /// + [JsonIgnore] public ILoggerFactory LoggerFactory => this.Services.GetService() ?? NullLoggerFactory.Instance; @@ -184,6 +187,7 @@ public CultureInfo Culture /// /// Gets the associated with this . /// + [JsonIgnore] public IAIServiceSelector ServiceSelector => this.Services.GetService() ?? OrderedAIServiceSelector.Instance; diff --git a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj index 47043cbe1df8..4826426418a6 100644 --- a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj +++ b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj @@ -29,7 +29,7 @@ - + diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs index 24bc16a0f8e7..ccb32a42ef3f 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/AIServiceExtensions.cs @@ -66,7 +66,7 @@ public static class AIServiceExtensions /// /// /// Specifies the type of the required. This must be the same type - /// with which the service was registered in the orvia + /// with which the service was registered in the or via /// the . /// /// The to use to select a service from the . 
diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs index 93064508d118..353abb9715cc 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/IAIServiceSelector.cs @@ -16,7 +16,7 @@ public interface IAIServiceSelector /// /// /// Specifies the type of the required. This must be the same type - /// with which the service was registered in the orvia + /// with which the service was registered in the or via /// the . /// /// The containing services, plugins, and other state for use throughout the operation. diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/IChatClientSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/IChatClientSelector.cs new file mode 100644 index 000000000000..30f8e2bcb4e6 --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/Services/IChatClientSelector.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.SemanticKernel; + +#pragma warning disable CA1716 // Identifiers should not match keywords + +/// +/// Represents a selector which will return a combination of the containing instances of T and its pairing +/// from the specified provider based on the model settings. +/// +[Experimental("SKEXP0001")] +public interface IChatClientSelector +{ + /// + /// Resolves an and associated from the specified + /// based on a and associated . + /// + /// + /// Specifies the type of the required. This must be the same type + /// with which the service was registered in the or via + /// the . + /// + /// The containing services, plugins, and other state for use throughout the operation. + /// The function. + /// The function arguments. + /// The selected service, or null if none was selected. 
+ /// The settings associated with the selected service. This may be null even if a service is selected. + /// true if a matching service was selected; otherwise, false. + bool TrySelectChatClient( + Kernel kernel, + KernelFunction function, + KernelArguments arguments, + [NotNullWhen(true)] out T? service, + out PromptExecutionSettings? serviceSettings) where T : class, IChatClient; +} diff --git a/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs b/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs index 1200acd3a803..c11851c4d46d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Services/OrderedAIServiceSelector.cs @@ -3,7 +3,9 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Linq; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Services; @@ -11,15 +13,23 @@ namespace Microsoft.SemanticKernel.Services; /// Implementation of that selects the AI service based on the order of the execution settings. /// Uses the service id or model id to select the preferred service provider and then returns the service and associated execution settings. /// -internal sealed class OrderedAIServiceSelector : IAIServiceSelector +internal sealed class OrderedAIServiceSelector : IAIServiceSelector, IChatClientSelector { public static OrderedAIServiceSelector Instance { get; } = new(); /// - public bool TrySelectAIService( + [Experimental("SKEXP0001")] + public bool TrySelectChatClient(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? 
serviceSettings) where T : class, IChatClient + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); + + /// + public bool TrySelectAIService(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class, IAIService + => this.TrySelect(kernel, function, arguments, out service, out serviceSettings); + + private bool TrySelect( Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, - out PromptExecutionSettings? serviceSettings) where T : class, IAIService + out PromptExecutionSettings? serviceSettings) where T : class { // Allow the execution settings from the kernel arguments to take precedence var executionSettings = arguments.ExecutionSettings ?? function.ExecutionSettings; @@ -94,11 +104,20 @@ kernel.Services is IKeyedServiceProvider ? kernel.Services.GetService(); } - private T? GetServiceByModelId(Kernel kernel, string modelId) where T : class, IAIService + private T? GetServiceByModelId(Kernel kernel, string modelId) where T : class { foreach (var service in kernel.GetAllServices()) { - string? serviceModelId = service.GetModelId(); + string? 
serviceModelId = null; + if (service is IAIService aiService) + { + serviceModelId = aiService.GetModelId(); + } + else if (service is IChatClient chatClient) + { + serviceModelId = chatClient.GetModelId(); + } + if (!string.IsNullOrEmpty(serviceModelId) && serviceModelId == modelId) { return service; diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs index 2c212df12ef8..de5574375cb0 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromMethod.cs @@ -379,29 +379,14 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync - public override KernelFunction Clone(string pluginName) + public override KernelFunction Clone(string? pluginName = null) { - Verify.NotNullOrWhiteSpace(pluginName, nameof(pluginName)); - - if (base.JsonSerializerOptions is not null) + if (pluginName is not null) { - return new KernelFunctionFromMethod( - this.UnderlyingMethod!, - this._function, - this.Name, - pluginName, - this.Description, - this.Metadata.Parameters, - this.Metadata.ReturnParameter, - base.JsonSerializerOptions, - this.Metadata.AdditionalProperties); + Verify.NotNullOrWhiteSpace(pluginName, nameof(pluginName)); } - [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Non AOT scenario.")] - [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "Non AOT scenario.")] - KernelFunctionFromMethod Clone() - { - return new KernelFunctionFromMethod( + return new KernelFunctionFromMethod( this.UnderlyingMethod!, this._function, this.Name, @@ -409,10 +394,8 @@ KernelFunctionFromMethod Clone() this.Description, 
this.Metadata.Parameters, this.Metadata.ReturnParameter, + base.JsonSerializerOptions, this.Metadata.AdditionalProperties); - } - - return Clone(); } /// Delegate used to invoke the underlying delegate. @@ -470,7 +453,7 @@ private KernelFunctionFromMethod( KernelVerify.ValidFunctionName(functionName); this._function = implementationFunc; - this.UnderlyingMethod = method; + this._underlyingMethod = method; } private KernelFunctionFromMethod( @@ -488,7 +471,7 @@ private KernelFunctionFromMethod( KernelVerify.ValidFunctionName(functionName); this._function = implementationFunc; - this.UnderlyingMethod = method; + this._underlyingMethod = method; } [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "This method is AOT save.")] @@ -811,7 +794,7 @@ private static bool TryToDeserializeValue(object value, Type targetType, JsonSer JsonDocument document => document.Deserialize(targetType, jsonSerializerOptions), JsonNode node => node.Deserialize(targetType, jsonSerializerOptions), JsonElement element => element.Deserialize(targetType, jsonSerializerOptions), - // The JSON can be represented by other data types from various libraries. For example, JObject, JToken, and JValue from the Newtonsoft.Json library. + // The JSON can be represented by other data types from various libraries. For example, JObject, JToken, and JValue from the Newtonsoft.Json library. // Since we don't take dependencies on these libraries and don't have access to the types here, // the only way to deserialize those types is to convert them to a string first by calling the 'ToString' method. // Attempting to use the 'JsonSerializer.Serialize' method, instead of calling the 'ToString' directly on those types, can lead to unpredictable outcomes. 
@@ -1095,7 +1078,7 @@ private static void ThrowForInvalidSignatureIf([DoesNotReturnIf(true)] bool cond { if (input?.GetType() is Type type && converter.CanConvertFrom(type)) { - // This line performs string to type conversion + // This line performs string to type conversion return converter.ConvertFrom(context: null, culture, input); } diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 0d7a76065de1..72219af73b80 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -7,9 +7,11 @@ using System.Diagnostics.Metrics; using System.Linq; using System.Runtime.CompilerServices; +using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; @@ -249,6 +251,7 @@ protected override async ValueTask InvokeCoreAsync( { IChatCompletionService chatCompletion => await this.GetChatCompletionResultAsync(chatCompletion, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false), ITextGenerationService textGeneration => await this.GetTextGenerationResultAsync(textGeneration, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false), + IChatClient chatClient => await this.GetChatClientResultAsync(chatClient, kernel, promptRenderingResult, cancellationToken).ConfigureAwait(false), // The service selector didn't find an appropriate service. This should only happen with a poorly implemented selector. _ => throw new NotSupportedException($"The AI service {promptRenderingResult.AIService.GetType()} is not supported. 
Supported services are {typeof(IChatCompletionService)} and {typeof(ITextGenerationService)}") }; @@ -268,7 +271,7 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync? asyncReference = null; + IAsyncEnumerable? asyncReference = null; if (result.AIService is IChatCompletionService chatCompletion) { @@ -278,74 +281,138 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync (TResult)(object)content.ToString(), - - _ when content is TResult contentAsT - => contentAsT, - - _ when content.InnerContent is TResult innerContentAsT - => innerContentAsT, - - _ when typeof(TResult) == typeof(byte[]) - => (TResult)(object)content.ToByteArray(), + if (typeof(TResult) == typeof(string)) + { + yield return (TResult)(object)kernelContent.ToString(); + continue; + } + + if (content is TResult contentAsT) + { + yield return contentAsT; + continue; + } + + if (kernelContent.InnerContent is TResult innerContentAsT) + { + yield return innerContentAsT; + continue; + } + + if (typeof(TResult) == typeof(byte[])) + { + if (content is StreamingKernelContent byteKernelContent) + { + yield return (TResult)(object)byteKernelContent.ToByteArray(); + continue; + } + } + + // Attempting to use the new Microsoft Extensions AI types will trigger automatic conversion of SK chat contents. 
+ if (typeof(ChatResponseUpdate).IsAssignableFrom(typeof(TResult)) + && content is StreamingChatMessageContent streamingChatMessageContent) + { + yield return (TResult)(object)streamingChatMessageContent.ToChatResponseUpdate(); + continue; + } + } + else if (content is ChatResponseUpdate chatUpdate) + { + if (typeof(TResult) == typeof(string)) + { + yield return (TResult)(object)chatUpdate.ToString(); + continue; + } + + if (chatUpdate is TResult contentAsT) + { + yield return contentAsT; + continue; + } + + if (chatUpdate.Contents is TResult contentListsAsT) + { + yield return contentListsAsT; + continue; + } + + if (chatUpdate.RawRepresentation is TResult rawRepresentationAsT) + { + yield return rawRepresentationAsT; + continue; + } + + if (typeof(Microsoft.Extensions.AI.AIContent).IsAssignableFrom(typeof(TResult))) + { + // Return the first matching content type of an update if any + var updateContent = chatUpdate.Contents.FirstOrDefault(c => c is TResult); + if (updateContent is not null) + { + yield return (TResult)(object)updateContent; + continue; + } + } + + if (typeof(TResult) == typeof(byte[])) + { + DataContent? dataContent = (DataContent?)chatUpdate.Contents.FirstOrDefault(c => c is DataContent dataContent); + if (dataContent is not null) + { + yield return (TResult)(object)dataContent.Data.ToArray(); + continue; + } + } + + // Avoid breaking changes this transformation will be dropped once we migrate fully to Microsoft Extensions AI abstractions. + // This is also necessary to not break existing code using KernelContents when using IChatClient connectors. + if (typeof(StreamingKernelContent).IsAssignableFrom(typeof(TResult))) + { + yield return (TResult)(object)chatUpdate.ToStreamingChatMessageContent(); + continue; + } + } - _ => throw new NotSupportedException($"The specific type {typeof(TResult)} is not supported. 
Support types are {typeof(StreamingTextContent)}, string, byte[], or a matching type for {typeof(StreamingTextContent)}.{nameof(StreamingTextContent.InnerContent)} property") - }; + throw new NotSupportedException($"The specific type {typeof(TResult)} is not supported. Support types are derivations of {typeof(StreamingKernelContent)}, {typeof(StreamingKernelContent)}, string, byte[], or a matching type for {typeof(StreamingKernelContent)}.{nameof(StreamingKernelContent.InnerContent)} property"); } // There is no post cancellation check to override the result as the stream data was already sent. } /// - public override KernelFunction Clone(string pluginName) + public override KernelFunction Clone(string? pluginName = null) { - Verify.NotNullOrWhiteSpace(pluginName, nameof(pluginName)); - - if (base.JsonSerializerOptions is not null) + if (pluginName is not null) { - return new KernelFunctionFromPrompt( - this._promptTemplate, - this.Name, - pluginName, - this.Description, - this.Metadata.Parameters, - base.JsonSerializerOptions, - this.Metadata.ReturnParameter, - this.ExecutionSettings as Dictionary ?? 
this.ExecutionSettings!.ToDictionary(kv => kv.Key, kv => kv.Value), - this._inputVariables, - this._logger); + Verify.NotNullOrWhiteSpace(pluginName, nameof(pluginName)); } - [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Non AOT scenario.")] - [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "Non AOT scenario.")] - KernelFunctionFromPrompt Clone() - { - return new KernelFunctionFromPrompt( + return new KernelFunctionFromPrompt( this._promptTemplate, this.Name, pluginName, this.Description, this.Metadata.Parameters, + base.JsonSerializerOptions, this.Metadata.ReturnParameter, this.ExecutionSettings as Dictionary ?? this.ExecutionSettings!.ToDictionary(kv => kv.Key, kv => kv.Value), this._inputVariables, this._logger); - } - - return Clone(); } [RequiresUnreferencedCode("Uses reflection to handle various aspects of the function creation and invocation, making it incompatible with AOT scenarios.")] @@ -447,13 +514,13 @@ private KernelFunctionFromPrompt( private const string MeasurementModelTagName = "semantic_kernel.function.model_id"; /// to record function invocation prompt token usage. - private static readonly Histogram s_invocationTokenUsagePrompt = s_meter.CreateHistogram( + private static readonly Histogram s_invocationTokenUsagePrompt = s_meter.CreateHistogram( name: "semantic_kernel.function.invocation.token_usage.prompt", unit: "{token}", description: "Measures the prompt token usage"); /// to record function invocation completion token usage. 
- private static readonly Histogram s_invocationTokenUsageCompletion = s_meter.CreateHistogram( + private static readonly Histogram s_invocationTokenUsageCompletion = s_meter.CreateHistogram( name: "semantic_kernel.function.invocation.token_usage.completion", unit: "{token}", description: "Measures the completion token usage"); @@ -478,7 +545,7 @@ private async Task RenderPromptAsync( { var serviceSelector = kernel.ServiceSelector; - IAIService? aiService; + IAIService? aiService = null; string renderedPrompt = string.Empty; // Try to use IChatCompletionService. @@ -488,12 +555,41 @@ private async Task RenderPromptAsync( { aiService = chatService; } - else + else if (serviceSelector.TrySelectAIService( + kernel, this, arguments, + out ITextGenerationService? textService, out executionSettings)) + { + aiService = textService; + } +#pragma warning disable CA2000 // Dispose objects before losing scope + else if (serviceSelector is IChatClientSelector chatClientServiceSelector + && chatClientServiceSelector.TrySelectChatClient(kernel, this, arguments, out var chatClient, out executionSettings)) + { + // Resolves a ChatClient as AIService so it doesn't need to implement IChatCompletionService. + aiService = new ChatClientAIService(chatClient); + } + + if (aiService is null) { - // If IChatCompletionService isn't available, try to fallback to ITextGenerationService, - // throwing if it's not available. 
- (aiService, executionSettings) = serviceSelector.SelectAIService(kernel, this, arguments); + var message = new StringBuilder().Append("No service was found for any of the supported types: ").Append(typeof(IChatCompletionService)).Append(", ").Append(typeof(ITextGenerationService)).Append(", ").Append(typeof(IChatClient)).Append('.'); + if (this.ExecutionSettings is not null) + { + string serviceIds = string.Join("|", this.ExecutionSettings.Keys); + if (!string.IsNullOrEmpty(serviceIds)) + { + message.Append(" Expected serviceIds: ").Append(serviceIds).Append('.'); + } + + string modelIds = string.Join("|", this.ExecutionSettings.Values.Select(model => model.ModelId)); + if (!string.IsNullOrEmpty(modelIds)) + { + message.Append(" Expected modelIds: ").Append(modelIds).Append('.'); + } + } + + throw new KernelException(message.ToString()); } +#pragma warning restore CA2000 // Dispose objects before losing scope Verify.NotNull(aiService); @@ -594,16 +690,30 @@ JsonElement SerializeToElement(object? 
value) s_invocationTokenUsagePrompt.Record(promptTokens, in tags); s_invocationTokenUsageCompletion.Record(completionTokens, in tags); } - else if (jsonObject.TryGetProperty("InputTokenCount", out var inputTokensJson) && - inputTokensJson.TryGetInt32(out int inputTokens) && - jsonObject.TryGetProperty("OutputTokenCount", out var outputTokensJson) && - outputTokensJson.TryGetInt32(out int outputTokens)) + else if (jsonObject.TryGetProperty("InputTokenCount", out var pascalInputTokensJson) && + pascalInputTokensJson.TryGetInt32(out int pascalInputTokens) && + jsonObject.TryGetProperty("OutputTokenCount", out var pascalOutputTokensJson) && + pascalOutputTokensJson.TryGetInt32(out int pascalOutputTokens)) { TagList tags = new() { { MeasurementFunctionTagName, this.Name }, { MeasurementModelTagName, modelId } }; + s_invocationTokenUsagePrompt.Record(pascalInputTokens, in tags); + s_invocationTokenUsageCompletion.Record(pascalOutputTokens, in tags); + } + else if (jsonObject.TryGetProperty("inputTokenCount", out var inputTokensJson) && + inputTokensJson.TryGetInt32(out int inputTokens) && + jsonObject.TryGetProperty("outputTokenCount", out var outputTokensJson) && + outputTokensJson.TryGetInt32(out int outputTokens)) + { + TagList tags = new() + { + { MeasurementFunctionTagName, this.Name }, + { MeasurementModelTagName, modelId } + }; + s_invocationTokenUsagePrompt.Record(inputTokens, in tags); s_invocationTokenUsageCompletion.Record(outputTokens, in tags); } @@ -613,6 +723,46 @@ JsonElement SerializeToElement(object? value) } } + /// + /// Captures usage details, including token information. + /// + private void CaptureUsageDetails(string? modelId, UsageDetails? usageDetails, ILogger logger) + { + if (!logger.IsEnabled(LogLevel.Information) && + !s_invocationTokenUsageCompletion.Enabled && + !s_invocationTokenUsagePrompt.Enabled) + { + // Bail early to avoid unnecessary work. 
+ return; + } + + if (string.IsNullOrWhiteSpace(modelId)) + { + logger.LogInformation("No model ID provided to capture usage details."); + return; + } + + if (usageDetails is null) + { + logger.LogInformation("No usage details was provided."); + return; + } + + if (usageDetails.InputTokenCount.HasValue && usageDetails.OutputTokenCount.HasValue) + { + TagList tags = new() { + { MeasurementFunctionTagName, this.Name }, + { MeasurementModelTagName, modelId } + }; + s_invocationTokenUsagePrompt.Record(usageDetails.InputTokenCount.Value, in tags); + s_invocationTokenUsageCompletion.Record(usageDetails.OutputTokenCount.Value, in tags); + } + else + { + logger.LogWarning("Unable to get token details from model result."); + } + } + private async Task GetChatCompletionResultAsync( IChatCompletionService chatCompletion, Kernel kernel, @@ -644,6 +794,40 @@ private async Task GetChatCompletionResultAsync( return new FunctionResult(this, chatContents, kernel.Culture) { RenderedPrompt = promptRenderingResult.RenderedPrompt }; } + private async Task GetChatClientResultAsync( + IChatClient chatClient, + Kernel kernel, + PromptRenderingResult promptRenderingResult, + CancellationToken cancellationToken) + { + var chatResponse = await chatClient.GetResponseAsync( + promptRenderingResult.RenderedPrompt, + promptRenderingResult.ExecutionSettings, + kernel, + cancellationToken).ConfigureAwait(false); + + if (chatResponse.Messages is { Count: 0 }) + { + return new FunctionResult(this, chatResponse) + { + Culture = kernel.Culture, + RenderedPrompt = promptRenderingResult.RenderedPrompt + }; + } + + var modelId = chatClient.GetService()?.DefaultModelId; + + // Usage details are global and duplicated for each chat message content, use first one to get usage information + this.CaptureUsageDetails(chatClient.GetService()?.DefaultModelId, chatResponse.Usage, this._logger); + + return new FunctionResult(this, chatResponse) + { + Culture = kernel.Culture, + RenderedPrompt = 
promptRenderingResult.RenderedPrompt, + Metadata = chatResponse.AdditionalProperties, + }; + } + private async Task GetTextGenerationResultAsync( ITextGenerationService textGeneration, Kernel kernel, diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs index d2b7634eec6b..a4d5b9990e7e 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using Microsoft.Extensions.AI; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; @@ -26,4 +29,184 @@ public void ShouldAssignIsRequiredParameterMetadataPropertyCorrectly() KernelParameterMetadata? p2Metadata = sut.Metadata.Parameters.FirstOrDefault(p => p.Name == "p2"); Assert.False(p2Metadata?.IsRequired); } + + [Fact] + public void ShouldUseKernelFunctionNameWhenWrappingKernelFunction() + { + // Arrange + var kernelFunction = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + + // Act + AIFunctionKernelFunction sut = new(kernelFunction); + + // Assert + Assert.Equal("TestFunction", sut.Name); + } + + [Fact] + public void ShouldUseKernelFunctionPluginAndNameWhenWrappingKernelFunction() + { + // Arrange + var kernelFunction = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction") + .Clone("TestPlugin"); // Simulate a plugin name + + // Act + AIFunctionKernelFunction sut = new(kernelFunction); + + // Assert + Assert.Equal("TestPlugin_TestFunction", sut.Name); + Assert.Equal("TestPlugin", sut.PluginName); + } + + [Fact] + public void ShouldUseNameOnlyInToStringWhenWrappingKernelFunctionWithPlugin() + { + // Arrange + var kernelFunction = 
KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction") + .Clone("TestPlugin"); + + // Act + AIFunctionKernelFunction sut = new(kernelFunction); + + // Assert + Assert.Equal("TestPlugin_TestFunction", sut.ToString()); + } + + [Fact] + public void ShouldUseAIFunctionNameWhenWrappingNonKernelFunction() + { + // Arrange + var aiFunction = new TestAIFunction("CustomName"); + + // Act + AIFunctionKernelFunction sut = new(aiFunction); + + // Assert + Assert.Equal("CustomName", sut.Name); + Assert.Null(sut.PluginName); + } + + [Fact] + public void ShouldPreserveDescriptionFromAIFunction() + { + // Arrange + var aiFunction = new TestAIFunction("TestFunction", "This is a test description"); + + // Act + AIFunctionKernelFunction sut = new(aiFunction); + + // Assert + Assert.Equal("This is a test description", sut.Description); + } + + [Fact] + public async Task ShouldInvokeUnderlyingAIFunctionWhenInvoked() + { + // Arrange + var testAIFunction = new TestAIFunction("TestFunction"); + AIFunctionKernelFunction sut = new(testAIFunction); + var kernel = new Kernel(); + var arguments = new KernelArguments(); + + // Act + await sut.InvokeAsync(kernel, arguments); + + // Assert + Assert.True(testAIFunction.WasInvoked); + } + + [Fact] + public void ShouldCloneCorrectlyWithNewPluginName() + { + // Arrange + var aiFunction = new TestAIFunction("TestFunction"); + AIFunctionKernelFunction original = new(aiFunction); + + // Act + var cloned = original.Clone("NewPlugin"); + + // Assert + Assert.Equal("NewPlugin", cloned.PluginName); + Assert.Equal("NewPlugin_TestFunction", cloned.Name); + Assert.Equal("NewPlugin_TestFunction", cloned.ToString()); + } + + [Fact] + public async Task ClonedFunctionShouldInvokeOriginalAIFunction() + { + // Arrange + var testAIFunction = new TestAIFunction("TestFunction"); + AIFunctionKernelFunction original = new(testAIFunction); + var cloned = original.Clone("NewPlugin"); + var kernel = new Kernel(); + var arguments = new KernelArguments(); + + 
// Act + await cloned.InvokeAsync(kernel, arguments); + + // Assert + Assert.True(testAIFunction.WasInvoked); + } + + [Fact] + public async Task ShouldUseProvidedKernelWhenInvoking() + { + // Arrange + var kernel1 = new Kernel(); + var kernel2 = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + var aiFunction = new AIFunctionKernelFunction(function); + + // Clone with a new plugin name + var clonedFunction = aiFunction.Clone("NewPlugin"); + + // Act + var result1 = await clonedFunction.InvokeAsync(kernel1, new()); + var result2 = await clonedFunction.InvokeAsync(kernel2, new()); + + // Assert - verify that the results are different when using different kernels + var result1Str = result1.GetValue()?.ToString(); + var result2Str = result2.GetValue()?.ToString(); + Assert.NotNull(result1Str); + Assert.NotNull(result2Str); + Assert.NotEqual(result1Str, result2Str); + } + + [Fact] + public void ShouldThrowWhenPluginNameIsNullOrWhitespace() + { + // Arrange + var aiFunction = new TestAIFunction("TestFunction"); + AIFunctionKernelFunction original = new(aiFunction); + + // Act & Assert + Assert.Throws(() => original.Clone(string.Empty)); + Assert.Throws(() => original.Clone(" ")); + } + + private sealed class TestAIFunction : AIFunction + { + public bool WasInvoked { get; private set; } + + public TestAIFunction(string name, string description = "") + { + this.Name = name; + this.Description = description; + } + + public override string Name { get; } + + public override string Description { get; } + + protected override ValueTask InvokeCoreAsync(AIFunctionArguments? 
arguments = null, CancellationToken cancellationToken = default) + { + this.WasInvoked = true; + return ValueTask.FromResult("Test result"); + } + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs index cd6a1e943a4f..6b705263a010 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ServiceConversionExtensionsTests.cs @@ -30,7 +30,7 @@ public void EmbeddingGenerationInvalidArgumentsThrow() public void ChatCompletionInvalidArgumentsThrow() { Assert.Throws("service", () => ChatCompletionServiceExtensions.AsChatClient(null!)); - Assert.Throws("client", () => ChatCompletionServiceExtensions.AsChatCompletionService(null!)); + Assert.Throws("client", () => Microsoft.SemanticKernel.ChatCompletion.ChatClientExtensions.AsChatCompletionService(null!)); } [Fact] @@ -304,11 +304,11 @@ public async Task AsChatClientNonStreamingToolsPropagated(ChatToolMode mode) [ new NopAIFunction("AIFunc1"), new NopAIFunction("AIFunc2"), - KernelFunctionFactory.CreateFromMethod(() => "invoked", "NiftyFunction").AsAIFunction(), + KernelFunctionFactory.CreateFromMethod(() => "invoked", "NiftyFunction"), .. 
KernelPluginFactory.CreateFromFunctions("NiftyPlugin", [ KernelFunctionFactory.CreateFromMethod(() => "invoked", "NiftyFunction") - ]).AsAIFunctions(), + ]), ], ToolMode = mode, }); diff --git a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs index c9c348d1ac44..c25fa0ccd249 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs @@ -64,7 +64,6 @@ public void ItProvideStatusForResponsesWithoutContent() // Assert Assert.NotNull(httpOperationException); Assert.NotNull(httpOperationException.StatusCode); - Assert.Null(httpOperationException.ResponseContent); Assert.Equal(exception, httpOperationException.InnerException); Assert.Equal(exception.Message, httpOperationException.Message); Assert.Equal(pipelineResponse.Status, (int)httpOperationException.StatusCode!); diff --git a/dotnet/src/SemanticKernel.UnitTests/Filters/AutoFunctionInvocation/AutoFunctionInvocationContextTests.cs b/dotnet/src/SemanticKernel.UnitTests/Filters/AutoFunctionInvocation/AutoFunctionInvocationContextTests.cs new file mode 100644 index 000000000000..61930b1c2007 --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Filters/AutoFunctionInvocation/AutoFunctionInvocationContextTests.cs @@ -0,0 +1,566 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Xunit; +using ChatMessageContent = Microsoft.SemanticKernel.ChatMessageContent; + +namespace SemanticKernel.UnitTests.Filters.AutoFunctionInvocation; + +public class AutoFunctionInvocationContextTests +{ + [Fact] + public void ConstructorWithValidParametersCreatesInstance() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + // Assert + Assert.NotNull(context); + Assert.Same(kernel, context.Kernel); + Assert.Same(function, context.Function); + Assert.Same(result, context.Result); + Assert.Same(chatHistory, context.ChatHistory); + Assert.Same(chatMessageContent, context.ChatMessageContent); + } + + [Fact] + public void ConstructorWithNullKernelThrowsException() + { + // Arrange + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + null!, + function, + result, + chatHistory, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullFunctionThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new 
ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + null!, + result, + chatHistory, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullResultThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + function, + null!, + chatHistory, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullChatHistoryThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + function, + result, + null!, + chatMessageContent)); + } + + [Fact] + public void ConstructorWithNullChatMessageContentThrowsException() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + null!)); + } + + [Fact] + public void PropertiesReturnCorrectValues() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act + var context = 
new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + // Assert + Assert.Same(kernel, context.Kernel); + Assert.Same(function, context.Function); + Assert.Same(result, context.Result); + Assert.Same(chatHistory, context.ChatHistory); + Assert.Same(chatMessageContent, context.ChatMessageContent); + } + + [Fact] + public async Task AutoFunctionInvocationContextCanBeUsedInFilter() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + bool filterWasCalled = false; + + // Create a simple filter that just sets a flag + async Task FilterMethod(AutoFunctionInvocationContext ctx, Func next) + { + filterWasCalled = true; + Assert.Same(context, ctx); + await next(ctx); + } + + // Act + await FilterMethod(context, _ => Task.CompletedTask); + + // Assert + Assert.True(filterWasCalled); + } + + [Fact] + public void ExecutionSettingsCanBeSetAndRetrieved() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var executionSettings = new PromptExecutionSettings(); + + var options = new KernelChatOptions(kernel, settings: executionSettings) + { + ChatMessageContent = chatMessageContent, + }; + + // Act + var context = new AutoFunctionInvocationContext(options, function); + + // Assert + Assert.Same(executionSettings, context.ExecutionSettings); + } + + [Fact] + public async Task 
KernelFunctionCloneWithKernelUsesProvidedKernel() + { + // Arrange + var originalKernel = new Kernel(); + var newKernel = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + // Act + // Create AIFunctions with different kernels + var aiFunction1 = function.WithKernel(originalKernel); + var aiFunction2 = function.WithKernel(newKernel); + + // Invoke both functions + var args = new AIFunctionArguments(); + var result1 = await aiFunction1.InvokeAsync(args, default); + var result2 = await aiFunction2.InvokeAsync(args, default); + + // Assert + // The results should be different because they use different kernels + Assert.NotNull(result1); + Assert.NotNull(result2); + Assert.NotEqual(result1, result2); + Assert.Equal(originalKernel.GetHashCode().ToString(), result1.ToString()); + Assert.Equal(newKernel.GetHashCode().ToString(), result2.ToString()); + } + + // Let's simplify our approach and use a different testing strategy + [Fact] + public void ArgumentsPropertyHandlesKernelArguments() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Create KernelArguments and set them via the init property + var kernelArgs = new KernelArguments { ["test"] = "value" }; + + // Set the arguments via the init property + var contextWithArgs = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = kernelArgs + }; + + // Act & Assert + Assert.Same(kernelArgs, contextWithArgs.Arguments); + } + + [Fact] + public void ArgumentsPropertyInitializesEmptyKernelArgumentsWhenSetToNull() + { + // Arrange + var kernel = new 
Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Set the arguments to null via the init property + var contextWithNullArgs = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = null + }; + + // Act & Assert + Assert.NotNull(contextWithNullArgs.Arguments); + Assert.IsType(contextWithNullArgs.Arguments); + Assert.Empty(contextWithNullArgs.Arguments); + } + + [Fact] + public void ArgumentsPropertyCanBeSetWithMultipleValues() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Create KernelArguments with multiple values + var kernelArgs = new KernelArguments + { + ["string"] = "value", + ["int"] = 42, + ["bool"] = true, + ["object"] = new object() + }; + + // Set the arguments via the init property + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = kernelArgs + }; + + // Act & Assert + Assert.Same(kernelArgs, context.Arguments); + Assert.Equal(4, context.Arguments.Count); + Assert.Equal("value", context.Arguments["string"]); + Assert.Equal(42, context.Arguments["int"]); + Assert.Equal(true, context.Arguments["bool"]); + Assert.NotNull(context.Arguments["object"]); + } + + [Fact] + public void ArgumentsPropertyCanBeSetWithExecutionSettings() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = 
new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var executionSettings = new PromptExecutionSettings(); + + // Create KernelArguments with execution settings + var kernelArgs = new KernelArguments(executionSettings) + { + ["test"] = "value" + }; + + // Set the arguments via the init property + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent) + { + Arguments = kernelArgs + }; + + // Act & Assert + Assert.Same(kernelArgs, context.Arguments); + Assert.Equal("value", context.Arguments["test"]); + Assert.Same(executionSettings, context.Arguments.ExecutionSettings?[PromptExecutionSettings.DefaultServiceId]); + } + + [Fact] + public void ArgumentsPropertyThrowsWhenBaseArgumentsIsNotKernelArguments() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var result = new FunctionResult(function); + var chatHistory = new ChatHistory(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + var context = new AutoFunctionInvocationContext( + kernel, + function, + result, + chatHistory, + chatMessageContent); + + ((Microsoft.Extensions.AI.FunctionInvocationContext)context).Arguments = new AIFunctionArguments(); + + // Act & Assert + Assert.Throws(() => context.Arguments); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionSetsPropertiesCorrectly() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var executionSettings = new PromptExecutionSettings(); + + var options = new KernelChatOptions(kernel, settings: executionSettings) + { + ChatMessageContent = chatMessageContent, + }; + + // Act + var context = new 
AutoFunctionInvocationContext(options, function); + + // Assert + Assert.Same(kernel, context.Kernel); + Assert.Same(function, context.Function); + Assert.Same(executionSettings, context.ExecutionSettings); + Assert.Same(chatMessageContent, context.ChatMessageContent); + Assert.NotNull(context.Result); + Assert.Equal(kernel.Culture, context.Result.Culture); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionThrowsWithNullOptions() + { + // Arrange + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext(null!, function)); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionThrowsWithNullFunction() + { + // Arrange + var kernel = new Kernel(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + var options = new KernelChatOptions(kernel) + { + ChatMessageContent = chatMessageContent, + }; + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext(options, null!)); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionThrowsWithNonKernelFunction() + { + // Arrange + var kernel = new Kernel(); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var testAIFunction = new TestAIFunction("TestFunction"); + + var options = new KernelChatOptions(kernel) + { + ChatMessageContent = chatMessageContent, + }; + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext(options, testAIFunction)); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionThrowsWithMissingKernel() + { + // Arrange + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Act & Assert + Assert.Throws(() => + // Create options without kernel + new AutoFunctionInvocationContext(new(null!) 
{ ChatMessageContent = chatMessageContent }, function) + ); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionThrowsWithMissingChatMessageContent() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + + // Create options without chat message content + var options = new KernelChatOptions(kernel); + + // Act & Assert + Assert.Throws(() => new AutoFunctionInvocationContext(options, function)); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionCanSetAndRetrieveArguments() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + var kernelArgs = new KernelArguments { ["test"] = "value" }; + + // Create options with required properties + var options = new KernelChatOptions(kernel) + { + ChatMessageContent = chatMessageContent, + }; + + // Act + var context = new AutoFunctionInvocationContext(options, function) + { + Arguments = kernelArgs + }; + + // Assert + Assert.Same(kernelArgs, context.Arguments); + Assert.Equal("value", context.Arguments["test"]); + } + + [Fact] + public void InternalConstructorWithOptionsAndAIFunctionInitializesEmptyArgumentsWhenSetToNull() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + var chatMessageContent = new ChatMessageContent(AuthorRole.Assistant, "Test message"); + + // Create options with required properties + var options = new KernelChatOptions(kernel) + { + ChatMessageContent = chatMessageContent, + }; + // Act + var context = new AutoFunctionInvocationContext(options, function) + { + Arguments = null + }; + + // Assert + Assert.NotNull(context.Arguments); + Assert.IsType(context.Arguments); + Assert.Empty(context.Arguments); + } + + // Helper class for testing 
non-KernelFunction AIFunction + private sealed class TestAIFunction : AIFunction + { + public TestAIFunction(string name, string description = "") + { + this.Name = name; + this.Description = description; + } + + public override string Name { get; } + + public override string Description { get; } + + protected override ValueTask InvokeCoreAsync(AIFunctionArguments arguments, CancellationToken cancellationToken) + { + return ValueTask.FromResult("Test result"); + } + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs new file mode 100644 index 000000000000..7400875cdc9b --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Services; +using Xunit; + +namespace SemanticKernel.UnitTests.Functions; + +public class CustomAIChatClientSelectorTests +{ + [Fact] + public void ItGetsChatClientUsingModelIdAttribute() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient = new ChatClientTest(); + builder.Services.AddKeyedSingleton("service1", chatClient); + Kernel kernel = builder.Build(); + + var function = kernel.CreateFunctionFromPrompt("Hello AI"); + IChatClientSelector chatClientSelector = new CustomChatClientSelector(); + + // Act + chatClientSelector.TrySelectChatClient(kernel, function, [], out var selectedChatClient, out var defaultExecutionSettings); + + // Assert + Assert.NotNull(selectedChatClient); + Assert.Equal("Value1", 
selectedChatClient.GetModelId()); + Assert.Null(defaultExecutionSettings); + selectedChatClient.Dispose(); + } + + private sealed class CustomChatClientSelector : IChatClientSelector + { +#pragma warning disable CS8769 // Nullability of reference types in value doesn't match target type. Cannot use [NotNullWhen] because of access to internals from abstractions. + public bool TrySelectChatClient(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) + where T : class, IChatClient + { + var keyedService = (kernel.Services as IKeyedServiceProvider)?.GetKeyedService("service1"); + if (keyedService is null || keyedService.GetModelId() is null) + { + service = null; + serviceSettings = null; + return false; + } + + service = string.Equals(keyedService.GetModelId(), "Value1", StringComparison.OrdinalIgnoreCase) ? keyedService as T : null; + serviceSettings = null; + + if (service is null) + { + throw new InvalidOperationException("Service not found"); + } + + return true; + } + } + + private sealed class ChatClientTest : IChatClient + { + private readonly ChatClientMetadata _metadata; + + public ChatClientTest() + { + this._metadata = new ChatClientMetadata(defaultModelId: "Value1"); + } + + public void Dispose() + { + } + + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + return this._metadata; + } + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? 
options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs index a53d8550c4d7..4697a0958c64 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIServiceSelectorTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Services; @@ -33,7 +35,8 @@ public void ItGetsAIServiceUsingArbitraryAttributes() private sealed class CustomAIServiceSelector : IAIServiceSelector { #pragma warning disable CS8769 // Nullability of reference types in value doesn't match target type. Cannot use [NotNullWhen] because of access to internals from abstractions. - bool IAIServiceSelector.TrySelectAIService(Kernel kernel, KernelFunction function, KernelArguments arguments, out T? service, out PromptExecutionSettings? serviceSettings) where T : class + public bool TrySelectAIService(Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) + where T : class, IAIService { var keyedService = (kernel.Services as IKeyedServiceProvider)?.GetKeyedService("service1"); if (keyedService is null || keyedService.Attributes is null) @@ -45,6 +48,12 @@ bool IAIServiceSelector.TrySelectAIService(Kernel kernel, KernelFunction func service = keyedService.Attributes.ContainsKey("Key1") ? 
keyedService as T : null; serviceSettings = null; + + if (service is null) + { + throw new InvalidOperationException("Service not found"); + } + return true; } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs index 787718b6e8e4..44ba05c5e547 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/FunctionResultTests.cs @@ -3,8 +3,11 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Linq; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; using Xunit; +using MEAI = Microsoft.Extensions.AI; namespace SemanticKernel.UnitTests.Functions; @@ -134,4 +137,186 @@ public void GetValueWhenValueIsKernelContentGenericTypeMatchShouldReturn() Assert.Equal(valueType, target.GetValue()); Assert.Equal(valueType, target.GetValue()); } + + [Fact] + public void GetValueConvertsFromMEAIChatMessageToSKChatMessageContent() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAICompletion = OpenAI.Chat.OpenAIChatModelFactory.ChatCompletion( + role: OpenAI.Chat.ChatMessageRole.User, + content: new OpenAI.Chat.ChatMessageContent(expectedValue)); + + var valueType = new MEAI.ChatResponse( + [ + new MEAI.ChatMessage(MEAI.ChatRole.User, expectedValue) + { + RawRepresentation = openAICompletion.Content + }, + new MEAI.ChatMessage(MEAI.ChatRole.Assistant, expectedValue) + { + RawRepresentation = openAICompletion.Content + } + ]) + { + RawRepresentation = openAICompletion + }; + + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + var message = target.GetValue()!; + Assert.Equal(valueType.Messages.Last().Text, message.Content); + Assert.Same(valueType.Messages.Last().RawRepresentation, message.InnerContent); + } + + [Fact] + public void 
GetValueConvertsFromSKChatMessageContentToMEAIChatMessage() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAIChatMessage = OpenAI.Chat.ChatMessage.CreateUserMessage(expectedValue); + var valueType = new ChatMessageContent(AuthorRole.User, expectedValue) { InnerContent = openAIChatMessage }; + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + Assert.Equal(valueType.Content, target.GetValue()!.Text); + Assert.Same(valueType.InnerContent, target.GetValue()!.RawRepresentation); + } + + [Fact] + public void GetValueConvertsFromSKChatMessageContentToMEAIChatResponse() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAIChatMessage = OpenAI.Chat.ChatMessage.CreateUserMessage(expectedValue); + var valueType = new ChatMessageContent(AuthorRole.User, expectedValue) { InnerContent = openAIChatMessage }; + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + + Assert.Equal(valueType.Content, target.GetValue()!.Text); + Assert.Same(valueType.InnerContent, target.GetValue()!.Messages[0].RawRepresentation); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(5)] + public void GetValueConvertsFromSKChatMessageContentListToMEAIChatResponse(int listSize) + { + // Arrange + List multipleChoiceResponse = []; + for (int i = 0; i < listSize; i++) + { + multipleChoiceResponse.Add(new ChatMessageContent(AuthorRole.User, Guid.NewGuid().ToString()) + { + InnerContent = OpenAI.Chat.ChatMessage.CreateUserMessage(i.ToString()) + }); + } + FunctionResult target = new(KernelFunctionFactory.CreateFromMethod(() => { }), (IReadOnlyList)multipleChoiceResponse); + + // Act and Assert + // Ensure returns the ChatResponse for no choices as well + var result = target.GetValue()!; + Assert.NotNull(result); + + for (int i = 0; i < listSize; i++) + { + // Ensure the other choices are not added as messages, only the first choice is considered + Assert.Single(result.Messages); + 
+ if (i == 0) + { + // The first choice is converted to a message + Assert.Equal(multipleChoiceResponse[i].Content, result.Messages.Last().Text); + Assert.Same(multipleChoiceResponse[i].InnerContent, result.Messages.Last().RawRepresentation); + } + else + { + // Any following choices messages are ignored and should not match the result message + Assert.NotEqual(multipleChoiceResponse[i].Content, result.Text); + Assert.NotSame(multipleChoiceResponse[i].InnerContent, result.Messages.Last().RawRepresentation); + } + } + + if (listSize > 0) + { + // Ensure the conversion to the first message works in one or multiple choice response + Assert.Equal(multipleChoiceResponse[0].Content, target.GetValue()!.Text); + Assert.Same(multipleChoiceResponse[0].InnerContent, target.GetValue()!.RawRepresentation); + } + } + + [Fact] + public void GetValueThrowsForEmptyChoicesFromSKChatMessageContentListToMEAITypes() + { + // Arrange + List multipleChoiceResponse = []; + FunctionResult target = new(KernelFunctionFactory.CreateFromMethod(() => { }), (IReadOnlyList)multipleChoiceResponse); + + // Act and Assert + var exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + + exception = Assert.Throws(target.GetValue); + Assert.Contains("no choices", exception.Message); + } + + [Fact] + public void GetValueCanRetrieveMEAITypes() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var openAICompletion = OpenAI.Chat.OpenAIChatModelFactory.ChatCompletion( + role: OpenAI.Chat.ChatMessageRole.User, + content: new OpenAI.Chat.ChatMessageContent(expectedValue)); + + var valueType = new MEAI.ChatResponse( + new MEAI.ChatMessage(MEAI.ChatRole.User, expectedValue) + { + RawRepresentation = openAICompletion.Content + }) + { + RawRepresentation = openAICompletion + }; + + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + Assert.Same(valueType, target.GetValue()); + Assert.Same(valueType.Messages[0], 
target.GetValue()); + Assert.Same(valueType.Messages[0].Contents[0], target.GetValue()); + Assert.Same(valueType.Messages[0].Contents[0], target.GetValue()); + + // Check the the content list is returned + Assert.Same(valueType.Messages[0].Contents, target.GetValue>()!); + Assert.Same(valueType.Messages[0].Contents[0], target.GetValue>()![0]); + Assert.IsType(target.GetValue>()![0]); + + // Check the raw representations are returned + Assert.Same(valueType.RawRepresentation, target.GetValue()!); + Assert.Same(valueType.Messages[0].RawRepresentation, target.GetValue()!); + } + + [Fact] + public void GetValueThrowsForEmptyMessagesToMEAITypes() + { + // Arrange + string expectedValue = Guid.NewGuid().ToString(); + var valueType = new MEAI.ChatResponse([]); + FunctionResult target = new(s_nopFunction, valueType); + + // Act and Assert + Assert.Empty(target.GetValue()!.Messages); + + var exception = Assert.Throws(target.GetValue); + Assert.Contains("no messages", exception.Message); + + exception = Assert.Throws(target.GetValue); + Assert.Contains("no messages", exception.Message); + + exception = Assert.Throws(target.GetValue); + Assert.Contains("no messages", exception.Message); + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionCloneTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionCloneTests.cs new file mode 100644 index 000000000000..e4bc3aa9015c --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionCloneTests.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Xunit; + +namespace SemanticKernel.UnitTests.Functions; + +public class KernelFunctionCloneTests +{ + [Fact] + public async Task ClonedKernelFunctionUsesProvidedKernelWhenInvokingAsAIFunction() + { + // Arrange + var originalKernel = new Kernel(); + var newKernel = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + // Create an AIFunction from the KernelFunction with the original kernel + var aiFunction = function.WithKernel(originalKernel); + + // Act + // Clone the function and create a new AIFunction with the new kernel + var clonedFunction = function.Clone("TestPlugin"); + var clonedAIFunction = clonedFunction.WithKernel(newKernel); + + // Invoke both functions + var originalResult = await aiFunction.InvokeAsync(new AIFunctionArguments(), default); + var clonedResult = await clonedAIFunction.InvokeAsync(new AIFunctionArguments(), default); + + // Assert + // The results should be different because they use different kernels + Assert.NotNull(originalResult); + Assert.NotNull(clonedResult); + Assert.NotEqual(originalResult, clonedResult); + Assert.Equal(originalKernel.GetHashCode().ToString(), originalResult.ToString()); + Assert.Equal(newKernel.GetHashCode().ToString(), clonedResult.ToString()); + } + + [Fact] + public async Task KernelAIFunctionUsesProvidedKernelWhenInvoking() + { + // Arrange + var kernel1 = new Kernel(); + var kernel2 = new Kernel(); + + // Create a function that returns the kernel's hash code + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel k) => k.GetHashCode().ToString(), + "GetKernelHashCode"); + + // Act + // Create AIFunctions with different kernels + var aiFunction1 = function.WithKernel(kernel1); + var aiFunction2 = function.WithKernel(kernel2); + + // Invoke both 
functions + var result1 = await aiFunction1.InvokeAsync(new AIFunctionArguments(), default); + var result2 = await aiFunction2.InvokeAsync(new AIFunctionArguments(), default); + + // Assert + // The results should be different because they use different kernels + Assert.NotNull(result1); + Assert.NotNull(result2); + Assert.NotEqual(result1, result2); + Assert.Equal(kernel1.GetHashCode().ToString(), result1.ToString()); + Assert.Equal(kernel2.GetHashCode().ToString(), result2.ToString()); + } + + [Fact] + public void CloneStoresKernelForLaterUse() + { + // Arrange + var kernel = new Kernel(); + var function = KernelFunctionFactory.CreateFromMethod(() => "Test", "TestFunction"); + + // Act + var aiFunction = function.WithKernel(kernel); + + // Assert + // We can't directly access the private _kernel field, but we can verify it's used + // by checking that the AIFunction has the correct name format + Assert.Equal("TestFunction", aiFunction.Name); + } + + [Fact] + public void ClonePreservesMetadataButChangesPluginName() + { + // Arrange + var function = KernelFunctionFactory.CreateFromMethod( + () => "Test", + "TestFunction", + "Test description"); + + // Act + var clonedFunction = function.Clone("NewPlugin"); + + // Assert + Assert.Equal("TestFunction", clonedFunction.Name); + Assert.Equal("NewPlugin", clonedFunction.PluginName); + Assert.Equal("Test description", clonedFunction.Description); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests1.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests1.cs index 9ca0e9c55557..814cc08c22bc 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests1.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromMethodTests1.cs @@ -9,6 +9,7 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using 
Microsoft.SemanticKernel; using Moq; @@ -910,7 +911,7 @@ public async Task ItSupportsJsonElementArgumentsImplicitConversionAsync() functionName: "Test"); await function.InvokeAsync(this._kernel, arguments); - await function.AsAIFunction().InvokeAsync(new(arguments)); + await function.InvokeAsync(new AIFunctionArguments(arguments)); } [Fact] @@ -940,7 +941,7 @@ public async Task ItSupportsStringJsonElementArgumentsImplicitConversionAsync() functionName: "Test"); await function.InvokeAsync(this._kernel, arguments); - await function.AsAIFunction().InvokeAsync(new(arguments)); + await function.InvokeAsync(new AIFunctionArguments(arguments)); } [Fact] diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs index 72dc5199dafb..7bb77ce26105 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionFromPromptTests.cs @@ -17,8 +17,7 @@ using Moq; using SemanticKernel.UnitTests.Functions.JsonSerializerContexts; using Xunit; - -// ReSharper disable StringLiteralTypo +using MEAI = Microsoft.Extensions.AI; namespace SemanticKernel.UnitTests.Functions; @@ -150,18 +149,18 @@ public async Task ItUsesServiceIdWhenProvidedInMethodAsync() public async Task ItUsesChatServiceIdWhenProvidedInMethodAsync() { // Arrange - var mockTextGeneration1 = new Mock(); - var mockTextGeneration2 = new Mock(); + var mockTextGeneration = new Mock(); + var mockChatCompletion = new Mock(); var fakeTextContent = new TextContent("llmResult"); var fakeChatContent = new ChatMessageContent(AuthorRole.User, "content"); - mockTextGeneration1.Setup(c => c.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeTextContent]); - mockTextGeneration2.Setup(c => c.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeChatContent]); 
+ mockTextGeneration.Setup(c => c.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeTextContent]); + mockChatCompletion.Setup(c => c.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeChatContent]); IKernelBuilder builder = Kernel.CreateBuilder(); - builder.Services.AddKeyedSingleton("service1", mockTextGeneration1.Object); - builder.Services.AddKeyedSingleton("service2", mockTextGeneration2.Object); - builder.Services.AddKeyedSingleton("service3", mockTextGeneration1.Object); + builder.Services.AddKeyedSingleton("service1", mockTextGeneration.Object); + builder.Services.AddKeyedSingleton("service2", mockChatCompletion.Object); + builder.Services.AddKeyedSingleton("service3", mockTextGeneration.Object); Kernel kernel = builder.Build(); var func = kernel.CreateFunctionFromPrompt("my prompt", [new PromptExecutionSettings { ServiceId = "service2" }]); @@ -170,8 +169,41 @@ public async Task ItUsesChatServiceIdWhenProvidedInMethodAsync() await kernel.InvokeAsync(func); // Assert - mockTextGeneration1.Verify(a => a.GetTextContentsAsync("my prompt", It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); - mockTextGeneration2.Verify(a => a.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); + mockTextGeneration.Verify(a => a.GetTextContentsAsync("my prompt", It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + mockChatCompletion.Verify(a => a.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Once()); + } + + [Fact] + public async Task ItUsesChatClientIdWhenProvidedInMethodAsync() + { + // Arrange + var mockTextGeneration = new Mock(); + var mockChatCompletion = new Mock(); + var mockChatClient = new Mock(); + var fakeTextContent = new TextContent("llmResult"); + var fakeChatContent = new ChatMessageContent(AuthorRole.User, "content"); + var fakeChatResponse = new MEAI.ChatResponse(new MEAI.ChatMessage()); + + 
mockTextGeneration.Setup(c => c.GetTextContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeTextContent]); + mockChatCompletion.Setup(c => c.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).ReturnsAsync([fakeChatContent]); + mockChatClient.Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(fakeChatResponse); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddKeyedSingleton("service1", mockTextGeneration.Object); + builder.Services.AddKeyedSingleton("service2", mockChatClient.Object); + builder.Services.AddKeyedSingleton("service3", mockChatCompletion.Object); + Kernel kernel = builder.Build(); + + var func = kernel.CreateFunctionFromPrompt("my prompt", [new PromptExecutionSettings { ServiceId = "service2" }]); + + // Act + await kernel.InvokeAsync(func); + + // Assert + mockTextGeneration.Verify(a => a.GetTextContentsAsync("my prompt", It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + mockChatCompletion.Verify(a => a.GetChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny()), Times.Never()); + mockChatClient.Verify(a => a.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny()), Times.Once()); } [Fact] @@ -194,7 +226,7 @@ public async Task ItFailsIfInvalidServiceIdIsProvidedAsync() var exception = await Assert.ThrowsAsync(() => kernel.InvokeAsync(func)); // Assert - Assert.Equal("Required service of type Microsoft.SemanticKernel.TextGeneration.ITextGenerationService not registered. 
Expected serviceIds: service3.", exception.Message); + Assert.Contains("Expected serviceIds: service3.", exception.Message); } [Fact] @@ -219,6 +251,28 @@ public async Task ItParsesStandardizedPromptWhenServiceIsChatCompletionAsync() Assert.Equal("How many 20 cents can I get from 1 dollar?", fakeService.ChatHistory[1].Content); } + [Fact] + public async Task ItParsesStandardizedPromptWhenServiceIsChatClientAsync() + { + using var fakeService = new FakeChatClient(); + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + await kernel.InvokeAsync(function); + + Assert.NotNull(fakeService.ChatMessages); + Assert.Equal(2, fakeService.ChatMessages.Count); + Assert.Equal("You are a helpful assistant.", fakeService.ChatMessages[0].Text); + Assert.Equal("How many 20 cents can I get from 1 dollar?", fakeService.ChatMessages[1].Text); + } + [Fact] public async Task ItParsesStandardizedPromptWhenServiceIsStreamingChatCompletionAsync() { @@ -342,6 +396,79 @@ public async Task InvokeAsyncReturnsTheConnectorChatResultWhenInServiceIsOnlyCha Assert.Equal("something", result.GetValue()!.ToString()); } + [Fact] + public async Task InvokeAsyncReturnsTheConnectorChatResultWhenInServiceIsOnlyChatClientAsync() + { + var customTestType = new CustomTestType(); + var fakeChatMessage = new MEAI.ChatMessage(MEAI.ChatRole.User, "something") { RawRepresentation = customTestType }; + var fakeChatResponse = new MEAI.ChatResponse(fakeChatMessage); + Mock mockChatClient = new(); + mockChatClient.Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(fakeChatResponse); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + using var 
chatClient = mockChatClient.Object; + KernelBuilder builder = new(); + builder.Services.AddTransient((sp) => chatClient); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything"); + + var result = await kernel.InvokeAsync(function); + + Assert.Equal("something", result.GetValue()); + Assert.Equal("something", result.GetValue()!.Text); + Assert.Equal(MEAI.ChatRole.User, result.GetValue()!.Role); + Assert.Same(customTestType, result.GetValue()!); + Assert.Equal("something", result.GetValue()!.ToString()); + Assert.Equal("something", result.GetValue()!.ToString()); + } + + [Fact] + public async Task InvokeAsyncReturnsTheConnectorChatResultMessagesWhenInServiceIsOnlyChatClientAsync() + { + var firstMessageContent = "something 1"; + var lastMessageContent = "something 2"; + + var customTestType = new CustomTestType(); + var fakeChatResponse = new MEAI.ChatResponse([ + new MEAI.ChatMessage(MEAI.ChatRole.User, firstMessageContent), + new MEAI.ChatMessage(MEAI.ChatRole.Assistant, lastMessageContent) { RawRepresentation = customTestType } + ]); + + Mock mockChatClient = new(); + mockChatClient.Setup(c => c.GetResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(fakeChatResponse); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + using var chatClient = mockChatClient.Object; + KernelBuilder builder = new(); + builder.Services.AddTransient((sp) => chatClient); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt("Anything"); + + var result = await kernel.InvokeAsync(function); + + var response = result.GetValue(); + Assert.NotNull(response); + Assert.Collection(response.Messages, + item1 => + { + Assert.Equal(firstMessageContent, item1.Text); Assert.Equal(MEAI.ChatRole.User, item1.Role); + }, + item2 => + { + Assert.Equal(lastMessageContent, item2.Text); 
Assert.Equal(MEAI.ChatRole.Assistant, item2.Role); + }); + + // Other specific types will be checked against the first choice and last message + Assert.Equal(lastMessageContent, result.GetValue()); + Assert.Equal(lastMessageContent, result.GetValue()!.Text); + Assert.Equal(MEAI.ChatRole.Assistant, result.GetValue()!.Role); + Assert.Same(customTestType, result.GetValue()!); + Assert.Equal(lastMessageContent, result.GetValue()!.ToString()); + Assert.Equal(lastMessageContent, result.GetValue()!.ToString()); + } + [Fact] public async Task InvokeAsyncReturnsTheConnectorChatResultWhenInServiceIsChatAndTextCompletionAsync() { @@ -947,6 +1074,422 @@ public async Task ItCanBeCloned(JsonSerializerOptions? jsos) Assert.Equal("Prompt with a variable", result); } + [Fact] + public async Task ItCanRetrieveDirectMEAIChatMessageUpdatesAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Hi! How can "), + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex], update); + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, update.Text); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAITextContentAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Hi! 
How can "), + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].Contents[0], update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAIStringAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Hi! How can "), + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAIRawRepresentationAsync() + { + var rawRepresentation = OpenAI.Chat.OpenAIChatModelFactory.StreamingChatCompletionUpdate(contentUpdate: new OpenAI.Chat.ChatMessageContent("Hi!")); + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(role: MEAI.ChatRole.Assistant, content: "Hi! 
How can ") { RawRepresentation = rawRepresentation }, + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") { RawRepresentation = rawRepresentation } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].RawRepresentation, update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItCanRetrieveDirectMEAIContentListAsync() + { + var rawRepresentation = OpenAI.Chat.OpenAIChatModelFactory.StreamingChatCompletionUpdate(contentUpdate: new OpenAI.Chat.ChatMessageContent("Hi!")); + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(role: MEAI.ChatRole.Assistant, content: "Hi! 
How can ") { RawRepresentation = rawRepresentation }, + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") { RawRepresentation = rawRepresentation } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync>(KernelFunctionFromPrompt.Create("Prompt with {{$A}} variable"))) + { + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].Contents, update); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItConvertsFromMEAIChatMessageUpdateToSKStreamingChatMessageContentAsync() + { + var rawRepresentation = new { test = "a" }; + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Hi! How can ") { RawRepresentation = rawRepresentation }, + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") { RawRepresentation = rawRepresentation } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? 
+ """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, update.Content); + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].RawRepresentation, update.InnerContent); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItConvertsFromMEAIChatMessageUpdateToSKStreamingContentAsync() + { + var rawRepresentation = new { test = "a" }; + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Hi! How can ") { RawRepresentation = rawRepresentation }, + new MEAI.ChatResponseUpdate(role: null, content: "I assist you today?") { RawRepresentation = rawRepresentation } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + var streamingChatContent = Assert.IsType(update); + Assert.Same(fakeService.GetStreamingResponseResult![updateIndex].RawRepresentation, update.InnerContent); + + Assert.Equal(fakeService.GetStreamingResponseResult![updateIndex].Text, streamingChatContent.Content); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingResponseResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToMEAIChatResponseUpdate() + { + var innerContent = new { test = "a" }; + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! 
How can ") { InnerContent = innerContent }, + new StreamingChatMessageContent(null, "I assist you today?") { InnerContent = innerContent } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Same(fakeService.GetStreamingChatMessageContentsResult![updateIndex].InnerContent, update.RawRepresentation); + + Assert.Equal(fakeService.GetStreamingChatMessageContentsResult![updateIndex].Content, update.Text); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToStringAsync() + { + var innerContent = new { test = "a" }; + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! How can ") { InnerContent = innerContent }, + new StreamingChatMessageContent(null, "I assist you today?") { InnerContent = innerContent } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? 
+ """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Equal(fakeService.GetStreamingChatMessageContentsResult![updateIndex].Content, update); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToItselfAsync() + { + var innerContent = new { test = "a" }; + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! How can ") { InnerContent = innerContent }, + new StreamingChatMessageContent(null, "I assist you today?") { InnerContent = innerContent } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Same(fakeService.GetStreamingChatMessageContentsResult![updateIndex], update); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToInnerContentAsync() + { + var innerContent = new Random(); + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! 
How can ") { InnerContent = innerContent }, + new StreamingChatMessageContent(null, "I assist you today?") { InnerContent = innerContent } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? + """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Same(fakeService.GetStreamingChatMessageContentsResult![updateIndex].InnerContent, update); + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + [Fact] + public async Task ItConvertsFromSKStreamingChatMessageContentToBytesAsync() + { + var innerContent = new Random(); + var fakeService = new FakeChatCompletionService() + { + GetStreamingChatMessageContentsResult = [ + new StreamingChatMessageContent(AuthorRole.Assistant, "Hi! How can ") { InnerContent = innerContent }, + new StreamingChatMessageContent(null, "I assist you today?") { InnerContent = innerContent } + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? 
+ """); + + // Act + Assert + var updateIndex = 0; + await foreach (var update in kernel.InvokeStreamingAsync(function)) + { + Assert.Equal(fakeService.GetStreamingChatMessageContentsResult![updateIndex].Content, + fakeService.GetStreamingChatMessageContentsResult![updateIndex].Encoding.GetString(update)); + + updateIndex++; + } + + Assert.Equal(updateIndex, fakeService.GetStreamingChatMessageContentsResult.Count); + } + + /// + /// This scenario covers scenarios on attempting to get a ChatResponseUpdate from a ITextGenerationService. + /// + [Fact] + public async Task ItThrowsConvertingFromNonChatSKStreamingContentToMEAIChatResponseUpdate() + { + var fakeService = new FakeTextGenerationService() + { + GetStreamingTextContentsResult = [new StreamingTextContent("Hi!")] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt("How many 20 cents can I get from 1 dollar?"); + + // Act + Assert + await Assert.ThrowsAsync( + () => kernel.InvokeStreamingAsync(function).GetAsyncEnumerator().MoveNextAsync().AsTask()); + } + + [Fact] + public async Task ItThrowsWhenConvertingFromMEAIChatMessageUpdateWithNoDataContentToBytesAsync() + { + using var fakeService = new FakeChatClient() + { + GetStreamingResponseResult = [ + new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Hi! How can ") + ] + }; + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddTransient((sp) => fakeService); + Kernel kernel = builder.Build(); + + KernelFunction function = KernelFunctionFactory.CreateFromPrompt(""" + You are a helpful assistant. + How many 20 cents can I get from 1 dollar? 
+ """); + + // Act + Assert + await Assert.ThrowsAsync( + () => kernel.InvokeStreamingAsync(function).GetAsyncEnumerator().MoveNextAsync().AsTask()); + } + public enum KernelInvocationType { InvokePrompt, @@ -990,6 +1533,93 @@ public Task> GetTextContentsAsync(string prompt, Prom } } + private sealed class FakeChatCompletionService : IChatCompletionService + { + public IReadOnlyDictionary Attributes => throw new NotImplementedException(); + + public IList? GetStreamingChatMessageContentsResult { get; set; } + + public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var item in this.GetStreamingChatMessageContentsResult ?? [new StreamingChatMessageContent(AuthorRole.Assistant, "Something")]) + { + yield return item; + } + } +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + } + + private sealed class FakeTextGenerationService : ITextGenerationService + { + public IReadOnlyDictionary Attributes => throw new NotImplementedException(); + + public IList? GetStreamingTextContentsResult { get; set; } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GetStreamingTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? 
kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + foreach (var item in this.GetStreamingTextContentsResult ?? [new StreamingTextContent("Something")]) + { + yield return item; + } + } +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } + + private sealed class FakeChatClient : MEAI.IChatClient + { + public List? ChatMessages { get; private set; } + public List? GetStreamingResponseResult { get; set; } + + public void Dispose() + { + } + + public Task GetResponseAsync(IEnumerable messages, MEAI.ChatOptions? options = null, CancellationToken cancellationToken = default) + { + this.ChatMessages = messages.ToList(); + return Task.FromResult(new MEAI.ChatResponse(new MEAI.ChatMessage(MEAI.ChatRole.Assistant, "Something"))); + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + return new MEAI.ChatClientMetadata(); + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, + MEAI.ChatOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + this.ChatMessages = messages.ToList(); + foreach (var item in this.GetStreamingResponseResult ?? [new MEAI.ChatResponseUpdate(MEAI.ChatRole.Assistant, "Something")]) + { + yield return item; + } + } +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + } + private Mock GetMockTextGenerationService(IReadOnlyList? 
textContents = null) { var mockTextGenerationService = new Mock(); @@ -1012,5 +1642,9 @@ private Mock GetMockChatCompletionService(IReadOnlyList< return mockChatCompletionService; } + private sealed class CustomTestType + { + public string Name { get; set; } = "MyCustomType"; + } #endregion } diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionTests.cs new file mode 100644 index 000000000000..42e8f9b61ded --- /dev/null +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/KernelFunctionTests.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Moq; +using Xunit; + +namespace Microsoft.SemanticKernel.UnitTests.Functions; + +/// +/// Tests for cloning with a instance. +/// +public class KernelFunctionTests +{ + private readonly Mock _loggerFactory = new(); + + [Fact] + public async Task ClonedFunctionWithKernelUsesProvidedKernelWhenInvokedWithoutKernel() + { + // Arrange + // Create a function that will return the kernel's ID + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel kernel) => kernel.Data.TryGetValue("id", out var id) ? id?.ToString() ?? 
string.Empty : string.Empty, + functionName: "GetKernelId", + description: "Gets the ID of the kernel used for invocation", + loggerFactory: this._loggerFactory.Object); + + // Create two kernels with different IDs + var kernel1 = new Kernel(); + kernel1.Data["id"] = "kernel1"; + + var kernel2 = new Kernel(); + kernel2.Data["id"] = "kernel2"; + + // Clone the function with kernel2 + var clonedFunction = function.WithKernel(kernel2); + + // Act + // Invoke the cloned function without providing a kernel + var result = await clonedFunction.InvokeAsync(); + + // Assert + // The function should have used kernel2 + Assert.NotNull(result); + Assert.Equal("kernel2", result.ToString()); + } + + [Fact] + public async Task ClonedFunctionWithKernelUsesProvidedKernelWhenInvokedWithNullKernel() + { + // Arrange + // Create a function that will return the kernel's ID + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel kernel) => kernel.Data["id"]!.ToString(), + functionName: "GetKernelId", + description: "Gets the ID of the kernel used for invocation", + loggerFactory: this._loggerFactory.Object); + + // Create two kernels with different IDs + var kernel1 = new Kernel(); + kernel1.Data["id"] = "kernel1"; + + var kernel2 = new Kernel(); + kernel2.Data["id"] = "kernel2"; + + // Clone the function with kernel2 + var clonedFunction = function.WithKernel(kernel2); + + // Act + // Invoke the cloned function with null kernel + var result = await clonedFunction.InvokeAsync(kernel: null!); + + // Assert + // The function should have used kernel2 + Assert.Equal("kernel2", result.GetValue()); + } + + [Fact] + public async Task ClonedFunctionWithKernelUsesExplicitKernelWhenProvidedInInvoke() + { + // Arrange + // Create a function that will return the kernel's ID + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel kernel) => kernel.Data.TryGetValue("id", out var id) ? id?.ToString() ?? 
string.Empty : string.Empty, + functionName: "GetKernelId", + description: "Gets the ID of the kernel used for invocation", + loggerFactory: this._loggerFactory.Object); + + // Create two kernels with different IDs + var kernel1 = new Kernel(); + kernel1.Data["id"] = "kernel1"; + + var kernel2 = new Kernel(); + kernel2.Data["id"] = "kernel2"; + + // Clone the function with kernel2 + var clonedFunction = function.WithKernel(kernel2); + + // Act + // Invoke the cloned function with kernel1 explicitly + var result = await clonedFunction.InvokeAsync(kernel: kernel1); + + // Assert + // The function should have used kernel1, not kernel2 + Assert.Equal("kernel1", result.GetValue()); + } + + [Fact] + public async Task ClonedFunctionWithKernelUsesProvidedKernelWhenInvokedWithArguments() + { + // Arrange + // Create a function that will return the kernel's ID + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel kernel) => kernel.Data.TryGetValue("id", out var id) ? id?.ToString() ?? 
string.Empty : string.Empty, + functionName: "GetKernelId", + description: "Gets the ID of the kernel used for invocation", + loggerFactory: this._loggerFactory.Object); + + // Create two kernels with different IDs + var kernel1 = new Kernel(); + kernel1.Data["id"] = "kernel1"; + + var kernel2 = new Kernel(); + kernel2.Data["id"] = "kernel2"; + + // Clone the function with kernel2 + var clonedFunction = function.WithKernel(kernel2); + + // Act + // Invoke the cloned function with just arguments + var result = await clonedFunction.InvokeAsync(); + + // Assert + // The function should have used kernel2 + Assert.NotNull(result); + Assert.Equal("kernel2", result.ToString()); + } + + [Fact] + public async Task ClonedFunctionWithKernelUsesExplicitKernelWhenProvidedInArguments() + { + // Arrange + // Create a function that will return the kernel's ID + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel kernel) => kernel.Data["id"]!.ToString(), + functionName: "GetKernelId", + description: "Gets the ID of the kernel used for invocation", + loggerFactory: this._loggerFactory.Object); + + // Create two kernels with different IDs + var kernel1 = new Kernel(); + kernel1.Data["id"] = "kernel1"; + + var kernel2 = new Kernel(); + kernel2.Data["id"] = "kernel2"; + + // Clone the function with kernel2 + var clonedFunction = function.WithKernel(kernel2); + + // Act + // Invoke the cloned function with kernel1 explicitly + var result = await clonedFunction.InvokeAsync(kernel1, new KernelArguments()); + + // Assert + // The function should have used kernel1, not kernel2 + Assert.Equal("kernel1", result.GetValue()); + } + + [Fact] + public async Task NonClonedFunctionThrowsExceptionWhenInvokedWithNullKernel() + { + // Arrange + // Create a function that will return the kernel's ID + var function = KernelFunctionFactory.CreateFromMethod( + (Kernel kernel) => kernel.Data.TryGetValue("id", out var id) ? id?.ToString() ?? 
string.Empty : string.Empty, + functionName: "GetKernelId", + description: "Gets the ID of the kernel used for invocation", + loggerFactory: this._loggerFactory.Object); + + // Act & Assert + // Invoke the function with null kernel (without cloning it first) + // This should throw an ArgumentNullException because the function requires a kernel and none is provided + var exception = await Assert.ThrowsAsync( + async () => await function.InvokeAsync(kernel: null!)); + + // Verify the exception parameter name is 'kernel' + Assert.Equal("kernel", exception.ParamName); + } +} diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs index 40121103ce69..1c585726f82d 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/MultipleModelTests.cs @@ -60,7 +60,7 @@ public async Task ItFailsIfInvalidServiceIdIsProvidedAsync() var exception = await Assert.ThrowsAsync(() => kernel.InvokeAsync(func)); // Assert - Assert.Equal("Required service of type Microsoft.SemanticKernel.TextGeneration.ITextGenerationService not registered. 
Expected serviceIds: service3.", exception.Message); + Assert.Contains("Expected serviceIds: service3.", exception.Message); } [Theory] diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs index eafac8ac5ca3..52619c22af6e 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs @@ -4,10 +4,13 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Services; using Microsoft.SemanticKernel.TextGeneration; +using Moq; using Xunit; namespace SemanticKernel.UnitTests.Functions; @@ -25,6 +28,7 @@ public void ItThrowsAKernelExceptionForNoServices() // Act // Assert Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); + Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); } [Fact] @@ -46,6 +50,27 @@ public void ItGetsAIServiceConfigurationForSingleAIService() Assert.Null(defaultExecutionSettings); } + [Fact] + public void ItGetsChatClientConfigurationForSingleChatClient() + { + // Arrange + var mockChat = new Mock(); + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddKeyedSingleton("chat1", mockChat.Object); + Kernel kernel = builder.Build(); + + var function = kernel.CreateFunctionFromPrompt("Hello AI"); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var chatClient, out var defaultExecutionSettings); + chatClient?.Dispose(); + + // Assert + Assert.NotNull(chatClient); + Assert.Null(defaultExecutionSettings); + } + [Fact] public void 
ItGetsAIServiceConfigurationForSingleTextGeneration() { @@ -90,13 +115,41 @@ public void ItGetsAIServiceConfigurationForTextGenerationByServiceId() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItGetsChatClientConfigurationForChatClientByServiceId() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var promptConfig = new PromptTemplateConfig() { Template = "Hello AI" }; + var executionSettings = new PromptExecutionSettings(); + promptConfig.AddExecutionSettings(executionSettings, "chat2"); + var function = kernel.CreateFunctionFromPrompt(promptConfig); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + [Fact] public void ItThrowsAKernelExceptionForNotFoundService() { // Arrange IKernelBuilder builder = Kernel.CreateBuilder(); builder.Services.AddKeyedSingleton("service1", new TextGenerationService("model_id_1")); - builder.Services.AddKeyedSingleton("service2", new TextGenerationService("model_id_2")); + builder.Services.AddKeyedSingleton("service2", new ChatCompletionService("model_id_2")); Kernel kernel = builder.Build(); var promptConfig = new PromptTemplateConfig() { Template = "Hello AI" }; @@ -107,6 +160,7 @@ public void ItThrowsAKernelExceptionForNotFoundService() // Act // Assert Assert.Throws(() => 
serviceSelector.SelectAIService(kernel, function, [])); + Assert.Throws(() => serviceSelector.SelectAIService(kernel, function, [])); } [Fact] @@ -129,6 +183,30 @@ public void ItGetsDefaultServiceForNotFoundModel() Assert.Equal(kernel.GetRequiredService("service2"), aiService); } + [Fact] + public void ItGetsDefaultChatClientForNotFoundModel() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var promptConfig = new PromptTemplateConfig() { Template = "Hello AI" }; + promptConfig.AddExecutionSettings(new PromptExecutionSettings { ModelId = "notfound" }); + var function = kernel.CreateFunctionFromPrompt(promptConfig); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + // Assert + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + } + [Fact] public void ItUsesDefaultServiceForNoExecutionSettings() { @@ -148,6 +226,28 @@ public void ItUsesDefaultServiceForNoExecutionSettings() Assert.Null(defaultExecutionSettings); } + [Fact] + public void ItUsesDefaultChatClientForNoExecutionSettings() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + var function = kernel.CreateFunctionFromPrompt("Hello AI"); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, 
out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + Assert.Null(defaultExecutionSettings); + } + [Fact] public void ItUsesDefaultServiceAndSettingsForDefaultExecutionSettings() { @@ -171,6 +271,32 @@ public void ItUsesDefaultServiceAndSettingsForDefaultExecutionSettings() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItUsesDefaultChatClientAndSettingsForDefaultExecutionSettings() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var executionSettings = new PromptExecutionSettings(); + var function = kernel.CreateFunctionFromPrompt("Hello AI", executionSettings); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + [Fact] public void ItUsesDefaultServiceAndSettingsForDefaultId() { @@ -194,6 +320,32 @@ public void ItUsesDefaultServiceAndSettingsForDefaultId() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItUsesDefaultChatClientAndSettingsForDefaultId() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model_id_1"); + using var chatClient2 = new ChatClient("model_id_2"); + builder.Services.AddKeyedSingleton("chat1", chatClient1); + 
builder.Services.AddKeyedSingleton("chat2", chatClient2); + Kernel kernel = builder.Build(); + + var executionSettings = new PromptExecutionSettings(); + var function = kernel.CreateFunctionFromPrompt("Hello AI", executionSettings); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService("chat2"), aiService); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + [Theory] [InlineData(new string[] { "modelid_1" }, "modelid_1")] [InlineData(new string[] { "modelid_2" }, "modelid_2")] @@ -228,6 +380,44 @@ public void ItGetsAIServiceConfigurationByOrder(string[] serviceIds, string expe } } + [Theory] + [InlineData(new string[] { "modelid_1" }, "modelid_1")] + [InlineData(new string[] { "modelid_2" }, "modelid_2")] + [InlineData(new string[] { "modelid_3" }, "modelid_3")] + [InlineData(new string[] { "modelid_4", "modelid_1" }, "modelid_1")] + [InlineData(new string[] { "modelid_4", "" }, "modelid_3")] + public void ItGetsChatClientConfigurationByOrder(string[] serviceIds, string expectedModelId) + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("modelid_1"); + using var chatClient2 = new ChatClient("modelid_2"); + using var chatClient3 = new ChatClient("modelid_3"); + builder.Services.AddKeyedSingleton("modelid_1", chatClient1); + builder.Services.AddKeyedSingleton("modelid_2", chatClient2); + builder.Services.AddKeyedSingleton("modelid_3", chatClient3); + Kernel kernel = builder.Build(); + + var executionSettings = new Dictionary(); + foreach (var serviceId in serviceIds) + { + executionSettings.Add(serviceId, new PromptExecutionSettings() { ModelId = serviceId }); + } + var function = 
kernel.CreateFunctionFromPrompt(promptConfig: new PromptTemplateConfig() { Template = "Hello AI", ExecutionSettings = executionSettings }); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, [], out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.Equal(kernel.GetRequiredService(expectedModelId), aiService); + if (!string.IsNullOrEmpty(defaultExecutionSettings!.ModelId)) + { + Assert.Equal(expectedModelId, defaultExecutionSettings!.ModelId); + } + } + [Fact] public void ItGetsAIServiceConfigurationForTextGenerationByModelId() { @@ -253,6 +443,34 @@ public void ItGetsAIServiceConfigurationForTextGenerationByModelId() Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); } + [Fact] + public void ItGetsChatClientConfigurationForChatClientByModelId() + { + // Arrange + IKernelBuilder builder = Kernel.CreateBuilder(); + using var chatClient1 = new ChatClient("model1"); + using var chatClient2 = new ChatClient("model2"); + builder.Services.AddKeyedSingleton(null, chatClient1); + builder.Services.AddKeyedSingleton(null, chatClient2); + Kernel kernel = builder.Build(); + + var arguments = new KernelArguments(); + var executionSettings = new PromptExecutionSettings() { ModelId = "model2" }; + var function = kernel.CreateFunctionFromPrompt("Hello AI", executionSettings: executionSettings); + var serviceSelector = new OrderedAIServiceSelector(); + + // Act + serviceSelector.TrySelectChatClient(kernel, function, arguments, out var aiService, out var defaultExecutionSettings); + aiService?.Dispose(); + + // Assert + Assert.NotNull(aiService); + Assert.Equal("model2", aiService.GetModelId()); + var expectedExecutionSettings = executionSettings.Clone(); + expectedExecutionSettings.Freeze(); + Assert.Equivalent(expectedExecutionSettings, defaultExecutionSettings); + } + #region private private sealed class AIService : IAIService { @@ -270,7 
+488,7 @@ public TextGenerationService(string modelId) this._attributes.Add("ModelId", modelId); } - public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } @@ -280,5 +498,73 @@ public IAsyncEnumerable GetStreamingTextContentsAsync(stri throw new NotImplementedException(); } } + + private sealed class ChatCompletionService : IChatCompletionService + { + public IReadOnlyDictionary Attributes => this._attributes; + + private readonly Dictionary _attributes = []; + + public ChatCompletionService(string modelId) + { + this._attributes.Add("ModelId", modelId); + } + + public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } + + private sealed class ChatClient : IChatClient + { + public ChatClientMetadata Metadata { get; } + + public ChatClient(string modelId) + { + this.Metadata = new ChatClientMetadata(defaultModelId: modelId); + } + + public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetStreamingTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? 
kernel = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + /// + public object? GetService(Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + return + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(this.Metadata) ? this.Metadata : + null; + } + + public void Dispose() + { + } + } #endregion } diff --git a/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs b/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs index b7ed4fc4a480..866d9d7c7d41 100644 --- a/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/KernelTests.cs @@ -17,6 +17,7 @@ using Microsoft.SemanticKernel.TextGeneration; using Moq; using Xunit; +using MEAI = Microsoft.Extensions.AI; #pragma warning disable CS0618 // Events are deprecated @@ -257,6 +258,110 @@ public async Task InvokeStreamingAsyncCallsConnectorStreamingApiAsync() mockTextCompletion.Verify(m => m.GetStreamingTextContentsAsync(It.IsIn("Write a simple phrase about UnitTests importance"), It.IsAny(), It.IsAny(), It.IsAny()), Times.Exactly(1)); } + [Fact] + public async Task InvokeStreamingAsyncCallsWithMEAIContentsAndChatCompletionApiAsync() + { + // Arrange + var mockChatCompletion = this.SetupStreamingChatCompletionMocks( + new StreamingChatMessageContent(AuthorRole.User, "chunk1"), + new StreamingChatMessageContent(AuthorRole.User, "chunk2")); + + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddSingleton(mockChatCompletion.Object); + Kernel kernel = 
builder.Build(); + var prompt = "Write a simple phrase about UnitTests {{$input}}"; + var expectedPrompt = prompt.Replace("{{$input}}", "importance"); + var sut = KernelFunctionFactory.CreateFromPrompt(prompt); + var variables = new KernelArguments() { [InputParameterName] = "importance" }; + + var chunkCount = 0; + + // Act & Assert + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Text); + chunkCount++; + } + + Assert.Equal(2, chunkCount); + mockChatCompletion.Verify(m => m.GetStreamingChatMessageContentsAsync(It.Is((m) => m[0].Content == expectedPrompt), It.IsAny(), It.IsAny(), It.IsAny()), Times.Exactly(1)); + } + + [Fact] + public async Task InvokeStreamingAsyncGenericPermutationsCallsConnectorChatClientAsync() + { + // Arrange + var customRawItem = new MEAI.ChatOptions(); + var mockChatClient = this.SetupStreamingChatClientMock( + new MEAI.ChatResponseUpdate(role: MEAI.ChatRole.Assistant, content: "chunk1") { RawRepresentation = customRawItem }, + new MEAI.ChatResponseUpdate(role: null, content: "chunk2") { RawRepresentation = customRawItem }); + IKernelBuilder builder = Kernel.CreateBuilder(); + builder.Services.AddSingleton(mockChatClient.Object); + Kernel kernel = builder.Build(); + var prompt = "Write a simple phrase about UnitTests {{$input}}"; + var expectedPrompt = prompt.Replace("{{$input}}", "importance"); + var sut = KernelFunctionFactory.CreateFromPrompt(prompt); + var variables = new KernelArguments() { [InputParameterName] = "importance" }; + + var totalChunksExpected = 0; + var totalInvocationTimesExpected = 0; + + // Act & Assert + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + totalChunksExpected++; + Assert.Same(customRawItem, 
chunk.InnerContent); + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Content); + Assert.Same(customRawItem, chunk.InnerContent); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Text); + Assert.Same(customRawItem, chunk.RawRepresentation); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.ToString()); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync>(kernel, variables)) + { + Assert.Contains("chunk", chunk[0].ToString()); + totalChunksExpected++; + } + + totalInvocationTimesExpected++; + await foreach (var chunk in sut.InvokeStreamingAsync(kernel, variables)) + { + Assert.Contains("chunk", chunk.Text); + totalChunksExpected++; + } + + Assert.Equal(totalInvocationTimesExpected * 2, totalChunksExpected); + mockChatClient.Verify(m => m.GetStreamingResponseAsync(It.Is>((m) => m[0].Text == expectedPrompt), It.IsAny(), It.IsAny()), Times.Exactly(totalInvocationTimesExpected)); + } + [Fact] public async Task ValidateInvokeAsync() { @@ -316,6 +421,13 @@ public async IAsyncEnumerable GetStreamingChatMessa return (mockTextContent, mockTextCompletion); } + private Mock SetupStreamingChatCompletionMocks(params StreamingChatMessageContent[] streamingContents) + { + var mockChatCompletion = new Mock(); + mockChatCompletion.Setup(m => m.GetStreamingChatMessageContentsAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())).Returns(streamingContents.ToAsyncEnumerable()); + return mockChatCompletion; + } + private Mock SetupStreamingMocks(params StreamingTextContent[] streamingContents) { var mockTextCompletion = new Mock(); @@ -324,6 +436,15 @@ private Mock 
SetupStreamingMocks(params StreamingTextCon return mockTextCompletion; } + private Mock SetupStreamingChatClientMock(params MEAI.ChatResponseUpdate[] chatResponseUpdates) + { + var mockChatClient = new Mock(); + mockChatClient.Setup(m => m.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), It.IsAny())).Returns(chatResponseUpdates.ToAsyncEnumerable()); + mockChatClient.Setup(c => c.GetService(typeof(MEAI.ChatClientMetadata), It.IsAny())).Returns(new MEAI.ChatClientMetadata()); + + return mockChatClient; + } + private void AssertFilters(Kernel kernel1, Kernel kernel2) { var functionFilters1 = kernel1.GetAllServices().ToArray();