diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 7878b0d5b359..94e6c8de85ef 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -93,7 +93,6 @@ - @@ -109,7 +108,7 @@ - + diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs index a0aa892a7802..be69fe412d5e 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs @@ -109,11 +109,11 @@ private Kernel CreateKernelWithTwoServices(bool useChatClient) { builder.Services.AddKeyedChatClient( ServiceKeyBad, - new OpenAI.OpenAIClient("bad-key").AsChatClient(TestConfiguration.OpenAI.ChatModelId)); + new OpenAI.OpenAIClient("bad-key").GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient()); builder.Services.AddKeyedChatClient( ServiceKeyGood, - new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).AsChatClient(TestConfiguration.OpenAI.ChatModelId)); + new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient()); } else { @@ -122,14 +122,16 @@ private Kernel CreateKernelWithTwoServices(bool useChatClient) new Azure.AI.OpenAI.AzureOpenAIClient( new Uri(TestConfiguration.AzureOpenAI.Endpoint), new Azure.AzureKeyCredential("bad-key")) - .AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)); + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient()); builder.Services.AddKeyedChatClient( ServiceKeyGood, new Azure.AI.OpenAI.AzureOpenAIClient( new Uri(TestConfiguration.AzureOpenAI.Endpoint), new Azure.AzureKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) - .AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)); + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient()); } } else diff --git 
a/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs new file mode 100644 index 000000000000..1e053618a385 --- /dev/null +++ b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; + +namespace Filtering; + +public class ChatClient_AutoFunctionInvocationFiltering(ITestOutputHelper output) : BaseTest(output) +{ + /// + /// Shows how to use . + /// + [Fact] + public async Task UsingAutoFunctionInvocationFilter() + { + var builder = Kernel.CreateBuilder(); + + builder.AddOpenAIChatClient("gpt-4", TestConfiguration.OpenAI.ApiKey); + + // This filter outputs information about auto function invocation and returns overridden result. + builder.Services.AddSingleton(new AutoFunctionInvocationFilter(this.Output)); + + var kernel = builder.Build(); + + var function = KernelFunctionFactory.CreateFromMethod(() => "Result from function", "MyFunction"); + + kernel.ImportPluginFromFunctions("MyPlugin", [function]); + + var executionSettings = new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Required([function], autoInvoke: true) + }; + + var result = await kernel.InvokePromptAsync("Invoke provided function and return result", new(executionSettings)); + + Console.WriteLine(result); + + // Output: + // Request sequence number: 0 + // Function sequence number: 0 + // Total number of functions: 1 + // Result from auto function invocation filter. + } + + /// + /// Shows how to get list of function calls by using . 
+ /// + [Fact] + public async Task GetFunctionCallsWithFilterAsync() + { + var builder = Kernel.CreateBuilder(); + + builder.AddOpenAIChatCompletion("gpt-3.5-turbo-1106", TestConfiguration.OpenAI.ApiKey); + + builder.Services.AddSingleton(new FunctionCallsFilter(this.Output)); + + var kernel = builder.Build(); + + kernel.ImportPluginFromFunctions("HelperFunctions", + [ + kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."), + kernel.CreateFunctionFromMethod((string cityName) => + cityName switch + { + "Boston" => "61 and rainy", + "London" => "55 and cloudy", + "Miami" => "80 and sunny", + "Paris" => "60 and rainy", + "Tokyo" => "50 and sunny", + "Sydney" => "75 and sunny", + "Tel Aviv" => "80 and sunny", + _ => "31 and snowing", + }, "GetWeatherForCity", "Gets the current weather for the specified city"), + ]); + + var executionSettings = new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + await foreach (var chunk in kernel.InvokePromptStreamingAsync("Check current UTC time and return current weather in Boston city.", new(executionSettings))) + { + Console.WriteLine(chunk.ToString()); + } + + // Output: + // Request #0. Function call: HelperFunctions.GetCurrentUtcTime. + // Request #0. Function call: HelperFunctions.GetWeatherForCity. + // The current UTC time is {time of execution}, and the current weather in Boston is 61°F and rainy. + } + + /// Shows available syntax for auto function invocation filter. 
+ private sealed class AutoFunctionInvocationFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + // Example: get function information + var functionName = context.Function.Name; + + // Example: get chat history + var chatHistory = context.ChatHistory; + + // Example: get information about all functions which will be invoked + var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last()); + + // In function calling functionality there are two loops. + // Outer loop is "request" loop - it performs multiple requests to LLM until user ask will be satisfied. + // Inner loop is "function" loop - it handles LLM response with multiple function calls. + + // Workflow example: + // 1. Request to LLM #1 -> Response with 3 functions to call. + // 1.1. Function #1 called. + // 1.2. Function #2 called. + // 1.3. Function #3 called. + // 2. Request to LLM #2 -> Response with 2 functions to call. + // 2.1. Function #1 called. + // 2.2. Function #2 called. + + // context.RequestSequenceIndex - it's a sequence number of outer/request loop operation. + // context.FunctionSequenceIndex - it's a sequence number of inner/function loop operation. + // context.FunctionCount - number of functions which will be called per request (based on example above: 3 for first request, 2 for second request). + + // Example: get request sequence index + output.WriteLine($"Request sequence index: {context.RequestSequenceIndex}"); + + // Example: get function sequence index + output.WriteLine($"Function sequence index: {context.FunctionSequenceIndex}"); + + // Example: get total number of functions which will be called + output.WriteLine($"Total number of functions: {context.FunctionCount}"); + + // Calling next filter in pipeline or function itself. 
+ // By skipping this call, next filters and function won't be invoked, and function call loop will proceed to the next function. + await next(context); + + // Example: get function result + var result = context.Result; + + // Example: override function result value + context.Result = new FunctionResult(context.Result, "Result from auto function invocation filter"); + + // Example: Terminate function invocation + context.Terminate = true; + } + } + + /// Shows how to get list of all function calls per request. + private sealed class FunctionCallsFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + var chatHistory = context.ChatHistory; + var functionCalls = FunctionCallContent.GetFunctionCalls(chatHistory.Last()).ToArray(); + + if (functionCalls is { Length: > 0 }) + { + foreach (var functionCall in functionCalls) + { + output.WriteLine($"Request #{context.RequestSequenceIndex}. Function call: {functionCall.PluginName}.{functionCall.FunctionName}."); + } + } + + await next(context); + } + } +} diff --git a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs index d4631323c24d..02ddbdb3ec35 100644 --- a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs +++ b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs @@ -10,7 +10,7 @@ namespace KernelExamples; /// -/// This sample shows how to use a custom AI service selector to select a specific model by matching it's id. +/// This sample shows how to use a custom AI service selector to select a specific model by matching the model id. 
/// public class CustomAIServiceSelector(ITestOutputHelper output) : BaseTest(output) { @@ -39,7 +39,8 @@ public async Task UsingCustomSelectToSelectServiceByMatchingModelId() builder.Services .AddSingleton(customSelector) .AddKeyedChatClient("OpenAIChatClient", new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) - .AsChatClient("gpt-4o")); // Add a IChatClient to the kernel + .GetChatClient("gpt-4o") + .AsIChatClient()); // Add a IChatClient to the kernel Kernel kernel = builder.Build(); @@ -60,7 +61,6 @@ private sealed class GptAIServiceSelector(string modelNameStartsWith, ITestOutpu private readonly ITestOutputHelper _output = output; private readonly string _modelNameStartsWith = modelNameStartsWith; - /// private bool TrySelect( Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class @@ -78,7 +78,7 @@ private bool TrySelect( else if (serviceToCheck is IChatClient chatClient) { var metadata = chatClient.GetService(); - serviceModelId = metadata?.ModelId; + serviceModelId = metadata?.DefaultModelId; endpoint = metadata?.ProviderUri?.ToString(); } diff --git a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs index 8935e4d66d48..39106b957841 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs @@ -43,25 +43,25 @@ public async Task UseDependencyInjectionToCreateAgentAsync(bool useChatClient) IChatClient chatClient; if (this.UseOpenAIConfig) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey), - TestConfiguration.OpenAI.ChatModelId); + chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); } else if 
(!string.IsNullOrEmpty(this.ApiKey)) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)), - modelId: TestConfiguration.AzureOpenAI.ChatModelId); + credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatModelId) + .AsIChatClient(); } else { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new AzureCliCredential()), - modelId: TestConfiguration.AzureOpenAI.ChatModelId); + credential: new AzureCliCredential()) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatModelId) + .AsIChatClient(); } var functionCallingChatClient = chatClient!.AsKernelFunctionInvokingChatClient(); diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj index 57f983dd5c9b..04d35b9e6561 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj @@ -96,6 +96,15 @@ Always + + Always + + + Always + + + Always + diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs new file mode 100644 index 000000000000..dd8d94c99824 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -0,0 +1,793 @@ +// Copyright (c) Microsoft. All rights reserved.
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core; + +public sealed class AutoFunctionInvocationFilterChatClientTests : IDisposable +{ + private readonly MultipleHttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public AutoFunctionInvocationFilterChatClientTests() + { + this._messageHandlerStub = new MultipleHttpMessageHandlerStub(); + + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + + [Fact] + public async Task FiltersAreExecutedCorrectlyAsync() + { + // Arrange + int filterInvocations = 0; + int functionInvocations = 0; + int[] expectedRequestSequenceNumbers = [0, 0, 1, 1]; + int[] expectedFunctionSequenceNumbers = [0, 1, 0, 1]; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + Kernel? 
contextKernel = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + contextKernel = context.Kernel; + + if (context.ChatHistory.Last() is OpenAIChatMessageContent content) + { + Assert.Equal(2, content.ToolCalls.Count); + } + + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + filterInvocations++; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(4, filterInvocations); + Assert.Equal(4, functionInvocations); + Assert.Equal(expectedRequestSequenceNumbers, requestSequenceNumbers); + Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers); + Assert.Same(kernel, contextKernel); + Assert.Equal("Test chat response", result.ToString()); + } + + [Fact] + public async Task FunctionSequenceIndexIsCorrectForConcurrentCallsAsync() + { + // Arrange + List functionSequenceNumbers = []; + List expectedFunctionSequenceNumbers = [0, 1, 0, 1]; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, 
next) => + { + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() + { + AllowParallelCalls = true, + AllowConcurrentInvocation = true + }) + })); + + // Assert + Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers); + } + + [Fact] + public async Task FiltersAreExecutedCorrectlyOnStreamingAsync() + { + // Arrange + int filterInvocations = 0; + int functionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + if (context.ChatHistory.Last() is OpenAIChatMessageContent content) + { + Assert.Equal(2, content.ToolCalls.Count); + } + + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + filterInvocations++; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { } + + // Assert + Assert.Equal(4, filterInvocations); + Assert.Equal(4, functionInvocations); + Assert.Equal([0, 0, 1, 1], 
requestSequenceNumbers); + Assert.Equal([0, 1, 0, 1], functionSequenceNumbers); + } + + [Fact] + public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync() + { + // Arrange + var executionOrder = new List(); + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter1 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter1-Invoking"); + await next(context); + }); + + var filter2 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter2-Invoking"); + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + + // Case #1 - Add filter to services + builder.Services.AddSingleton(filter1); + + var kernel = builder.Build(); + + // Case #2 - Add filter to kernel + kernel.AutoFunctionInvocationFilters.Add(filter2); + + var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal("Filter1-Invoking", executionOrder[0]); + Assert.Equal("Filter2-Invoking", executionOrder[1]); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming) + { + // Arrange + var executionOrder = new List(); + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); 
+ + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter1 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter1-Invoking"); + await next(context); + executionOrder.Add("Filter1-Invoked"); + }); + + var filter2 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter2-Invoking"); + await next(context); + executionOrder.Add("Filter2-Invoked"); + }); + + var filter3 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter3-Invoking"); + await next(context); + executionOrder.Add("Filter3-Invoked"); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + builder.Services.AddSingleton(filter1); + builder.Services.AddSingleton(filter2); + builder.Services.AddSingleton(filter3); + + var kernel = builder.Build(); + + var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(settings))) + { } + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", new(settings)); + } + + // Assert + Assert.Equal("Filter1-Invoking", executionOrder[0]); + Assert.Equal("Filter2-Invoking", executionOrder[1]); + Assert.Equal("Filter3-Invoking", executionOrder[2]); + Assert.Equal("Filter3-Invoked", executionOrder[3]); + Assert.Equal("Filter2-Invoked", executionOrder[4]); + Assert.Equal("Filter1-Invoked", executionOrder[5]); + } + + [Fact] + public async Task FilterCanOverrideArgumentsAsync() + { + // Arrange + const string NewValue = "NewValue"; + + var 
function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + context.Arguments!["parameter"] = NewValue; + await next(context); + context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + var chatResponse = Assert.IsType(result.GetValue()); + Assert.NotNull(chatResponse); + + var lastFunctionResult = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(lastFunctionResult); + Assert.Equal("NewValue", lastFunctionResult.ToString()); + } + + [Fact] + public async Task FilterCanHandleExceptionAsync() + { + // Arrange + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2"); + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + try + { + await next(context); + } + catch (KernelException exception) + { + Assert.Equal("Exception from Function1", exception.Message); + context.Result = new FunctionResult(context.Result, "Result from filter"); + } + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + var chatClient = kernel.GetRequiredService(); + + var executionSettings = new PromptExecutionSettings { 
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var options = executionSettings.ToChatOptions(kernel); + List messageList = [new(ChatRole.System, "System message")]; + + // Act + var resultMessages = await chatClient.GetResponseAsync(messageList, options, CancellationToken.None); + + // Assert + var firstToolMessage = resultMessages.Messages.First(m => m.Role == ChatRole.Tool); + Assert.NotNull(firstToolMessage); + var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; + var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; + + Assert.NotNull(firstFunctionResult); + Assert.NotNull(secondFunctionResult); + Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); + Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString()); + } + + [Fact] + public async Task FilterCanHandleExceptionOnStreamingAsync() + { + // Arrange + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2"); + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + try + { + await next(context); + } + catch (KernelException) + { + context.Result = new FunctionResult(context.Result, "Result from filter"); + } + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var chatClient = kernel.GetRequiredService(); + + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var options = executionSettings.ToChatOptions(kernel); + List messageList = []; + + // Act + List streamingContent = []; + await foreach (var update in 
chatClient.GetStreamingResponseAsync(messageList, options, CancellationToken.None)) + { + streamingContent.Add(update); + } + var chatResponse = streamingContent.ToChatResponse(); + + // Assert + var firstToolMessage = chatResponse.Messages.First(m => m.Role == ChatRole.Tool); + Assert.NotNull(firstToolMessage); + var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; + var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; + + Assert.NotNull(firstFunctionResult); + Assert.NotNull(secondFunctionResult); + Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); + Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString()); + } + + [Fact] + public async Task FiltersCanSkipFunctionExecutionAsync() + { + // Arrange + int filterInvocations = 0; + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Filter delegate is invoked only for second function, the first one should be skipped. 
+ if (context.Function.Name == "MyPlugin_Function2") + { + await next(context); + } + + filterInvocations++; + }); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) }; + + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(2, filterInvocations); + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(1, secondFunctionInvocations); + } + + [Fact] + public async Task PreFilterCanTerminateOperationAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Terminating before first function, so all functions won't be invoked. 
+ context.Terminate = true; + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + } + + [Fact] + public async Task PreFilterCanTerminateOperationOnStreamingAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Terminating before first function, so all functions won't be invoked. 
+ context.Terminate = true; + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { } + + // Assert + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + } + + [Fact] + public async Task PostFilterCanTerminateOperationAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + // Terminating after first function, so second function won't be invoked. 
+ context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var functionResult = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal(1, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + Assert.Equal([0], requestSequenceNumbers); + Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + var chatResponse = functionResult.GetValue(); + Assert.NotNull(chatResponse); + + var result = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(result); + Assert.Equal("function1-value", result.ToString()); + } + + [Fact] + public async Task PostFilterCanTerminateOperationOnStreamingAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + // Terminating after first function, so second function won't be invoked. 
+ context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + List streamingContent = []; + + // Act + await foreach (var update in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { + streamingContent.Add(update); + } + + // Assert + Assert.Equal(1, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + Assert.Equal([0], requestSequenceNumbers); + Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + Assert.Equal(4, streamingContent.Count); + + var chatResponse = streamingContent.ToChatResponse(); + Assert.NotNull(chatResponse); + + var result = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(result); + Assert.Equal("function1-value", result.ToString()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? 
actualStreamingFlag = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter = new AutoFunctionInvocationFilter(async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + builder.Services.AddSingleton(filter); + + var kernel = builder.Build(); + + var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await kernel.InvokePromptStreamingAsync("Test prompt", new(settings)).ToListAsync(); + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", new(settings)); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } + + [Fact] + public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync() + { + // Arrange + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]); + + AutoFunctionInvocationContext? 
actualContext = null; + + var kernel = this.GetKernelWithFilter(plugin, (context, next) => + { + actualContext = context; + return Task.CompletedTask; + }); + + var expectedExecutionSettings = new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings)); + + // Assert + Assert.NotNull(actualContext); + Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings); + } + + [Fact] + public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync() + { + // Arrange + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]); + + AutoFunctionInvocationContext? actualContext = null; + + var kernel = this.GetKernelWithFilter(plugin, (context, next) => + { + actualContext = context; + return Task.CompletedTask; + }); + + var expectedExecutionSettings = new PromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings))) + { } + + // Assert + Assert.NotNull(actualContext); + Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } + + #region private + + private static object? 
GetLastFunctionResultFromChatResponse(ChatResponse chatResponse) + { + Assert.NotEmpty(chatResponse.Messages); + var chatMessage = chatResponse.Messages[^1]; + + Assert.NotEmpty(chatMessage.Contents); + Assert.Contains(chatMessage.Contents, c => c is Microsoft.Extensions.AI.FunctionResultContent); + + var resultContent = (Microsoft.Extensions.AI.FunctionResultContent)chatMessage.Contents.Last(c => c is Microsoft.Extensions.AI.FunctionResultContent); + return resultContent.Result; + } + +#pragma warning disable CA2000 // Dispose objects before losing scope + private static List GetFunctionCallingResponses() + { + return [ + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) } + ]; + } + + private static List GetFunctionCallingStreamingResponses() + { + return [ + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) } + ]; + } +#pragma warning restore CA2000 + + private Kernel GetKernelWithFilter( + KernelPlugin plugin, + Func, Task>? 
onAutoFunctionInvocation) + { + var builder = Kernel.CreateBuilder(); + var filter = new AutoFunctionInvocationFilter(onAutoFunctionInvocation); + + builder.Plugins.Add(plugin); + builder.Services.AddSingleton(filter); + + builder.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + return builder.Build(); + } + + private sealed class AutoFunctionInvocationFilter( + Func, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter + { + private readonly Func, Task>? _onAutoFunctionInvocation = onAutoFunctionInvocation; + + public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) => + this._onAutoFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs index 19992be01667..b308206b12d5 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs @@ -312,7 +312,7 @@ public async Task FilterCanOverrideArgumentsAsync() // Act var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -596,7 +596,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() Assert.Equal([0], requestSequenceNumbers); Assert.Equal([0], functionSequenceNumbers); - // Results of function invoked before termination should be returned + // Results of function invoked before termination should be returned Assert.Equal(3, streamingContent.Count); var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; diff --git 
a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs new file mode 100644 index 000000000000..437d347aa194 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + +public class OpenAIKernelBuilderExtensionsChatClientTests +{ + [Fact] + public void AddOpenAIChatClientNullArgsThrow() + { + // Arrange + IKernelBuilder builder = null!; + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act & Assert + var exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId)); + Assert.Equal("builder", exception.ParamName); + + exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId)); + Assert.Equal("builder", exception.ParamName); + + using var httpClient = new HttpClient(); + exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient)); + Assert.Equal("builder", exception.ParamName); + } + + [Fact] + public void AddOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + + // Assert + var kernel = builder.Build(); + 
Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + var openAIClient = new OpenAIClient("test_api_key"); + string serviceId = "test_service_id"; + + // Act + builder.AddOpenAIChatClient(modelId, openAIClient, serviceId); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + using var httpClient = new HttpClient(); + + // Act + builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs new file mode 100644 index 000000000000..7a3888b95f30 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + +public class OpenAIServiceCollectionExtensionsChatClientTests +{ + [Fact] + public void AddOpenAIChatClientNullArgsThrow() + { + // Arrange + ServiceCollection services = null!; + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act & Assert + var exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId)); + Assert.Equal("services", exception.ParamName); + + exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId)); + Assert.Equal("services", exception.ParamName); + + using var httpClient = new HttpClient(); + exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient)); + Assert.Equal("services", exception.ParamName); + } + + [Fact] + public void AddOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + var openAIClient = new OpenAIClient("test_api_key"); + string serviceId = "test_service_id"; + + // Act + 
services.AddOpenAIChatClient(modelId, openAIClient, serviceId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + using var httpClient = new HttpClient(); + // Act + services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient); + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientWorksWithKernel() + { + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + services.AddKernel(); + + var serviceProvider = services.BuildServiceProvider(); + var kernel = serviceProvider.GetRequiredService(); + + var serviceFromCollection = serviceProvider.GetKeyedService(serviceId); + var serviceFromKernel = kernel.GetRequiredService(serviceId); + + Assert.NotNull(serviceFromKernel); + Assert.Same(serviceFromCollection, serviceFromKernel); + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt new file mode 100644 index 000000000000..17ce94647fd5 --- /dev/null +++ 
b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt @@ -0,0 +1,9 @@ +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_GetCurrentWeather","arguments":"{\n\"location\": \"Boston, MA\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_FunctionWithException","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":2,"id":"3","type":"function","function":{"name":"MyPlugin_NonExistentFunction","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":3,"id":"4","type":"function","function":{"name":"MyPlugin_InvalidArguments","arguments":"invalid_arguments_format"}}]},"finish_reason":"tool_calls"}]} + +data: [DONE] diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json 
b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json new file mode 100644 index 000000000000..2c499b14089f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json @@ -0,0 +1,40 @@ +{ + "id": "response-id", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "1", + "type": "function", + "function": { + "name": "MyPlugin_Function1", + "arguments": "{\n\"parameter\": \"function1-value\"\n}" + } + }, + { + "id": "2", + "type": "function", + "function": { + "name": "MyPlugin_Function2", + "arguments": "{\n\"parameter\": \"function2-value\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt new file mode 100644 index 000000000000..c113e3fa97ca --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt @@ -0,0 +1,5 @@ +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: 
{"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: [DONE] diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj index 64a0e72bde6d..c17e878a7a42 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj @@ -37,6 +37,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs new file mode 100644 index 000000000000..9d1832b340ff --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http; +using Microsoft.Extensions.AI; +using OpenAI; + +namespace Microsoft.SemanticKernel; + +/// Extension methods for . +[Experimental("SKEXP0010")] +public static class OpenAIChatClientKernelBuilderExtensions +{ + #region Chat Completion + + /// + /// Adds an OpenAI to the . + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . 
+ public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + string apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + apiKey, + orgId, + serviceId, + httpClient); + + return builder; + } + + /// + /// Adds an OpenAI to the . + /// + /// The instance to augment. + /// OpenAI model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + OpenAIClient? openAIClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + openAIClient, + serviceId); + + return builder; + } + + /// + /// Adds a custom endpoint OpenAI to the . + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// Custom OpenAI Compatible Message API endpoint + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + Uri endpoint, + string? apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? 
httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + endpoint, + apiKey, + orgId, + serviceId, + httpClient); + + return builder; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs index a07a81fdb5b3..01307b9adc2a 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs @@ -20,7 +20,7 @@ namespace Microsoft.SemanticKernel; /// -/// Sponsor extensions class for . +/// Extension methods for . /// public static class OpenAIKernelBuilderExtensions { @@ -269,7 +269,7 @@ public static IKernelBuilder AddOpenAIFiles( #region Chat Completion /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an to the . /// /// The instance to augment. /// OpenAI model name, see https://platform.openai.com/docs/models @@ -304,7 +304,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _) } /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an to the . /// /// The instance to augment. /// OpenAI model id @@ -330,7 +330,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _) } /// - /// Adds the Custom Endpoint OpenAI chat completion service to the list. + /// Adds a custom endpoint to the . /// /// The instance to augment. 
/// OpenAI model name, see https://platform.openai.com/docs/models diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs new file mode 100644 index 000000000000..2954e958936a --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Http; +using OpenAI; + +namespace Microsoft.SemanticKernel; + +/// +/// Sponsor extensions class for . +/// +[Experimental("SKEXP0010")] +public static class OpenAIChatClientServiceCollectionExtensions +{ + /// + /// White space constant. + /// + private const string SingleSpace = " "; + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient( + this IServiceCollection services, + string modelId, + string apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? 
_) + { + var loggerFactory = serviceProvider.GetService(); + + return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) + .GetChatClient(modelId) + .AsIChatClient() + .AsKernelFunctionInvokingChatClient(loggerFactory); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient(this IServiceCollection services, + string modelId, + OpenAIClient? openAIClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + return (openAIClient ?? serviceProvider.GetRequiredService()) + .GetChatClient(modelId) + .AsIChatClient() + .AsKernelFunctionInvokingChatClient(loggerFactory); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the Custom OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// A Custom Message API compatible endpoint. + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . 
+ public static IServiceCollection AddOpenAIChatClient( + this IServiceCollection services, + string modelId, + Uri endpoint, + string? apiKey = null, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) + .GetChatClient(modelId) + .AsIChatClient() + .AsKernelFunctionInvokingChatClient(loggerFactory); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + private static OpenAIClientOptions GetClientOptions( + Uri? endpoint = null, + string? orgId = null, + HttpClient? httpClient = null) + { + OpenAIClientOptions options = new(); + + if (endpoint is not null) + { + options.Endpoint = endpoint; + } + + if (orgId is not null) + { + options.OrganizationId = orgId; + } + + if (httpClient is not null) + { + options.Transport = new HttpClientPipelineTransport(httpClient); + } + + return options; + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs index 90375307c533..9e1127fa8b55 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs @@ -22,7 +22,7 @@ public sealed class OpenAIAudioToTextTests() .AddUserSecrets() .Build(); - [RetryFact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [RetryFact] //(Skip = "OpenAI will often throttle requests. 
This test is for manual verification.")] public async Task OpenAIAudioToTextTestAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs index fe8ff155d9c5..ddfe6b997a25 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs @@ -12,7 +12,6 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Http.Resilience; -using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; using OpenAI; @@ -48,16 +47,16 @@ public async Task ItCanUseOpenAiChatForTextGenerationAsync() [Fact] public async Task ItCanUseOpenAiChatClientAndContentsAsync() { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); // Arrange - var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey); + var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey); var builder = Kernel.CreateBuilder(); - builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId)); + builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient()); var kernel = builder.Build(); var func = kernel.CreateFunctionFromPrompt( @@ -104,16 +103,16 @@ public async Task OpenAIStreamingTestAsync() [Fact] public async Task 
ItCanUseOpenAiStreamingChatClientAndContentsAsync() { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); // Arrange - var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey); + var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey); var builder = Kernel.CreateBuilder(); - builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId)); + builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient()); var kernel = builder.Build(); var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); @@ -179,7 +178,7 @@ public async Task OpenAIHttpRetryPolicyTestAsync() // Assert Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s)); - Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode); + Assert.Equal(HttpStatusCode.Unauthorized, exception.StatusCode); } [Fact] @@ -258,11 +257,11 @@ public async Task SemanticKernelVersionHeaderIsSentAsync() var kernel = this.CreateAndInitializeKernel(httpClient); // Act - var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); // Assert Assert.NotNull(httpHeaderHandler.RequestHeaders); - Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); + 
Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var _)); } //[Theory(Skip = "This test is for manual verification.")] @@ -301,18 +300,18 @@ public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? private Kernel CreateAndInitializeKernel(HttpClient? httpClient = null) { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); - var kernelBuilder = base.CreateKernelBuilder(); + var kernelBuilder = this.CreateKernelBuilder(); kernelBuilder.AddOpenAIChatCompletion( - modelId: OpenAIConfiguration.ChatModelId, - apiKey: OpenAIConfiguration.ApiKey, - serviceId: OpenAIConfiguration.ServiceId, + modelId: openAIConfiguration.ChatModelId, + apiKey: openAIConfiguration.ApiKey, + serviceId: openAIConfiguration.ServiceId, httpClient: httpClient); return kernelBuilder.Build(); diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs index c2818abe2502..420295fe4349 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs @@ -18,7 +18,7 @@ public sealed class OpenAITextToAudioTests .AddUserSecrets() .Build(); - [Fact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [Fact] //(Skip = "OpenAI will often throttle requests. 
This test is for manual verification.")] public async Task OpenAITextToAudioTestAsync() { // Arrange diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index 97cb426c307d..88f3da9d6a53 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -211,7 +211,7 @@ public FunctionCallsProcessor(ILogger? logger = null) { bool terminationRequested = false; - // Wait for all of the function invocations to complete, then add the results to the chat, but stop when we hit a + // Wait for all the function invocations to complete, then add the results to the chat, but stop when we hit a // function for which termination was requested. FunctionResultContext[] resultContexts = await Task.WhenAll(functionTasks).ConfigureAwait(false); foreach (FunctionResultContext resultContext in resultContexts) @@ -487,8 +487,8 @@ public static string ProcessFunctionResult(object functionResult) return stringResult; } - // This is an optimization to use ChatMessageContent content directly - // without unnecessary serialization of the whole message content class. + // This is an optimization to use ChatMessageContent content directly + // without unnecessary serialization of the whole message content class. 
if (functionResult is ChatMessageContent chatMessageContent) { return chatMessageContent.ToString(); diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs index 2fefb6ee9d16..3a8d561f4eaf 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs @@ -99,25 +99,25 @@ protected IChatClient AddChatClientToKernel(IKernelBuilder builder) IChatClient chatClient; if (this.UseOpenAIConfig) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey), - TestConfiguration.OpenAI.ChatModelId); + chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); } else if (!string.IsNullOrEmpty(this.ApiKey)) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)), - modelId: TestConfiguration.AzureOpenAI.ChatDeploymentName); + credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient(); } else { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new AzureCliCredential()), - modelId: TestConfiguration.AzureOpenAI.ChatDeploymentName); + credential: new AzureCliCredential()) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient(); } var functionCallingChatClient = chatClient!.AsKernelFunctionInvokingChatClient(); diff --git 
a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs deleted file mode 100644 index a0d6b1865a8f..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Buffers; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization.Metadata; -using System.Text.RegularExpressions; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; - -#pragma warning disable IDE0009 // Use explicit 'this.' qualifier -#pragma warning disable IDE1006 // Missing static prefix s_ suffix - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs - -/// Provides factory methods for creating commonly used implementations of . -[ExcludeFromCodeCoverage] -internal static partial class AIFunctionFactory -{ - /// Holds the default options instance used when creating function. - private static readonly AIFunctionFactoryOptions _defaultOptions = new(); - - /// Creates an instance for a method, specified via a delegate. - /// The method to be represented via the created . - /// Metadata to use to override defaults inferred from . - /// The created for invoking . - /// - /// - /// Return values are serialized to using 's - /// . Arguments that are not already of the expected type are - /// marshaled to the expected type via JSON and using 's - /// . 
If the argument is a , - /// , or , it is deserialized directly. If the argument is anything else unknown, - /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? options) - { - Verify.NotNull(method); - - return ReflectionAIFunction.Build(method.Method, method.Target, options ?? _defaultOptions); - } - - /// Creates an instance for a method, specified via a delegate. - /// The method to be represented via the created . - /// The name to use for the . - /// The description to use for the . - /// The used to marshal function parameters and any return value. - /// The created for invoking . - /// - /// - /// Return values are serialized to using . - /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using - /// . If the argument is a , , - /// or , it is deserialized directly. If the argument is anything else unknown, it is - /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(Delegate method, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) - { - Verify.NotNull(method); - - AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null - ? _defaultOptions - : new() - { - Name = name, - Description = description, - SerializerOptions = serializerOptions, - }; - - return ReflectionAIFunction.Build(method.Method, method.Target, createOptions); - } - - /// - /// Creates an instance for a method, specified via an instance - /// and an optional target object if the method is an instance method. - /// - /// The method to be represented via the created . - /// - /// The target object for the if it represents an instance method. 
- /// This should be if and only if is a static method. - /// - /// Metadata to use to override defaults inferred from . - /// The created for invoking . - /// - /// - /// Return values are serialized to using 's - /// . Arguments that are not already of the expected type are - /// marshaled to the expected type via JSON and using 's - /// . If the argument is a , - /// , or , it is deserialized directly. If the argument is anything else unknown, - /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options) - { - Verify.NotNull(method); - - return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); - } - - /// - /// Creates an instance for a method, specified via an instance - /// and an optional target object if the method is an instance method. - /// - /// The method to be represented via the created . - /// - /// The target object for the if it represents an instance method. - /// This should be if and only if is a static method. - /// - /// The name to use for the . - /// The description to use for the . - /// The used to marshal function parameters and return value. - /// The created for invoking . - /// - /// - /// Return values are serialized to using . - /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using - /// . If the argument is a , , - /// or , it is deserialized directly. If the argument is anything else unknown, it is - /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? 
serializerOptions = null) - { - Verify.NotNull(method); - - AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null - ? _defaultOptions - : new() - { - Name = name, - Description = description, - SerializerOptions = serializerOptions, - }; - - return ReflectionAIFunction.Build(method, target, createOptions); - } - - private sealed class ReflectionAIFunction : AIFunction - { - public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFunctionFactoryOptions options) - { - Verify.NotNull(method); - - if (method.ContainsGenericParameters) - { - throw new ArgumentException("Open generic methods are not supported", nameof(method)); - } - - if (!method.IsStatic && target is null) - { - throw new ArgumentNullException(nameof(target), "Target must not be null for an instance method."); - } - - ReflectionAIFunctionDescriptor functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options); - - if (target is null && options.AdditionalProperties is null) - { - // We can use a cached value for static methods not specifying additional properties. - return functionDescriptor.CachedDefaultInstance ??= new(functionDescriptor, target, options); - } - - return new(functionDescriptor, target, options); - } - - private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options) - { - FunctionDescriptor = functionDescriptor; - Target = target; - AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance; - } - - public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } - public object? 
Target { get; } - public override IReadOnlyDictionary AdditionalProperties { get; } - public override string Name => FunctionDescriptor.Name; - public override string Description => FunctionDescriptor.Description; - public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; - public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; - public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - protected override Task InvokeCoreAsync( - IEnumerable>? arguments, - CancellationToken cancellationToken) - { - var paramMarshallers = FunctionDescriptor.ParameterMarshallers; - object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; - - IReadOnlyDictionary argDict = - arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance : - arguments as IReadOnlyDictionary ?? - arguments. -#if NET8_0_OR_GREATER - ToDictionary(); -#else - ToDictionary(kvp => kvp.Key, kvp => kvp.Value); -#endif - for (int i = 0; i < args.Length; i++) - { - args[i] = paramMarshallers[i](argDict, cancellationToken); - } - - return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); - } - } - - /// - /// A descriptor for a .NET method-backed AIFunction that precomputes its marshalling delegates and JSON schema. - /// - private sealed class ReflectionAIFunctionDescriptor - { - private const int InnerCacheSoftLimit = 512; - private static readonly ConditionalWeakTable> _descriptorCache = new(); - - /// A boxed . - private static readonly object? _boxedDefaultCancellationToken = default(CancellationToken); - - /// - /// Gets or creates a descriptors using the specified method and options. - /// - public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFunctionFactoryOptions options) - { - JsonSerializerOptions serializerOptions = options.SerializerOptions ?? 
AIJsonUtilities.DefaultOptions; - AIJsonSchemaCreateOptions schemaOptions = options.JsonSchemaCreateOptions ?? AIJsonSchemaCreateOptions.Default; - serializerOptions.MakeReadOnly(); - ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); - - DescriptorKey key = new(method, options.Name, options.Description, schemaOptions); - if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) - { - return descriptor; - } - - descriptor = new(key, serializerOptions); - return innerCache.Count < InnerCacheSoftLimit - ? innerCache.GetOrAdd(key, descriptor) - : descriptor; - } - - private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) - { - // Get marshaling delegates for parameters. - ParameterInfo[] parameters = key.Method.GetParameters(); - ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length]; - for (int i = 0; i < parameters.Length; i++) - { - ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]); - } - - // Get a marshaling delegate for the return value. - ReturnParameterMarshaller = GetReturnParameterMarshaller(key.Method, serializerOptions); - - Method = key.Method; - Name = key.Name ?? GetFunctionName(key.Method); - Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; - JsonSerializerOptions = serializerOptions; - JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( - key.Method, - Name, - Description, - serializerOptions, - key.SchemaOptions); - } - - public string Name { get; } - public string Description { get; } - public MethodInfo Method { get; } - public JsonSerializerOptions JsonSerializerOptions { get; } - public JsonElement JsonSchema { get; } - public Func, CancellationToken, object?>[] ParameterMarshallers { get; } - public Func> ReturnParameterMarshaller { get; } - public ReflectionAIFunction? 
CachedDefaultInstance { get; set; } - - private static string GetFunctionName(MethodInfo method) - { - // Get the function name to use. - string name = SanitizeMemberName(method.Name); - - const string AsyncSuffix = "Async"; - if (IsAsyncMethod(method) && - name.EndsWith(AsyncSuffix, StringComparison.Ordinal) && - name.Length > AsyncSuffix.Length) - { - name = name.Substring(0, name.Length - AsyncSuffix.Length); - } - - return name; - - static bool IsAsyncMethod(MethodInfo method) - { - Type t = method.ReturnType; - - if (t == typeof(Task) || t == typeof(ValueTask)) - { - return true; - } - - if (t.IsGenericType) - { - t = t.GetGenericTypeDefinition(); - if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) - { - return true; - } - } - - return false; - } - } - - /// - /// Gets a delegate for handling the marshaling of a parameter. - /// - private static Func, CancellationToken, object?> GetParameterMarshaller( - JsonSerializerOptions serializerOptions, - ParameterInfo parameter) - { - if (string.IsNullOrWhiteSpace(parameter.Name)) - { - throw new ArgumentException("Parameter is missing a name.", nameof(parameter)); - } - - // Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found. - Type parameterType = parameter.ParameterType; - JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType); - - // For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync. - if (parameterType == typeof(CancellationToken)) - { - return static (_, cancellationToken) => - cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing - cancellationToken; - } - - // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. - return (arguments, _) => - { - // If the parameter has an argument specified in the dictionary, return that argument. 
- if (arguments.TryGetValue(parameter.Name, out object? value)) - { - return value switch - { - null => null, // Return as-is if null -- if the parameter is a struct this will be handled by MethodInfo.Invoke - _ when parameterType.IsInstanceOfType(value) => value, // Do nothing if value is assignable to parameter type - JsonElement element => JsonSerializer.Deserialize(element, typeInfo), - JsonDocument doc => JsonSerializer.Deserialize(doc, typeInfo), - JsonNode node => JsonSerializer.Deserialize(node, typeInfo), - _ => MarshallViaJsonRoundtrip(value), - }; - - object? MarshallViaJsonRoundtrip(object value) - { -#pragma warning disable CA1031 // Do not catch general exception types - try - { - string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType())); - return JsonSerializer.Deserialize(json, typeInfo); - } - catch - { - // Eat any exceptions and fall back to the original value to force a cast exception later on. - return value; - } -#pragma warning restore CA1031 - } - } - - // There was no argument for the parameter in the dictionary. - // Does it have a default value? - if (parameter.HasDefaultValue) - { - return parameter.DefaultValue; - } - - // Leave it empty. - return null; - }; - } - - /// - /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. 
- /// - private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) - { - Type returnType = method.ReturnType; - JsonTypeInfo returnTypeInfo; - - // Void - if (returnType == typeof(void)) - { - return static (_, _) => Task.FromResult(null); - } - - // Task - if (returnType == typeof(Task)) - { - return async static (result, _) => - { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); - return null; - }; - } - - // ValueTask - if (returnType == typeof(ValueTask)) - { - return async static (result, _) => - { - await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); - return null; - }; - } - - if (returnType.IsGenericType) - { - // Task - if (returnType.GetGenericTypeDefinition() == typeof(Task<>)) - { - MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult); - returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); - return async (taskObj, cancellationToken) => - { - await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); - object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - }; - } - - // ValueTask - if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) - { - MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask); - MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult); - returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType); - return async (taskObj, cancellationToken) => - { - var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task.ConfigureAwait(false); - object? 
result = ReflectionInvoke(asTaskResultGetter, task, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - }; - } - } - - // For everything else, just serialize the result as-is. - returnTypeInfo = serializerOptions.GetTypeInfo(returnType); - return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); - - static async Task SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) - { - if (returnTypeInfo.Kind is JsonTypeInfoKind.None) - { - // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. - return JsonSerializer.SerializeToElement(result, returnTypeInfo); - } - - // Serialize asynchronously to support potential IAsyncEnumerable responses. - using PooledMemoryStream stream = new(); -#if NET9_0_OR_GREATER - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - Utf8JsonReader reader = new(stream.GetBuffer()); - return JsonElement.ParseValue(ref reader); -#else - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - stream.Position = 0; - var serializerOptions = _defaultOptions.SerializerOptions ?? AIJsonUtilities.DefaultOptions; - return await JsonSerializer.DeserializeAsync(stream, serializerOptions.GetTypeInfo(typeof(JsonElement)), cancellationToken).ConfigureAwait(false); -#endif - } - - // Throws an exception if a result is found to be null unexpectedly - static object ThrowIfNullResult(object? result) => result ?? 
throw new InvalidOperationException("Function returned null unexpectedly."); - } - - private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!; - private static readonly MethodInfo _valueTaskAsTask = typeof(ValueTask<>).GetMethod(nameof(ValueTask.AsTask), BindingFlags.Instance | BindingFlags.Public)!; - - private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedType, MethodInfo genericMethodDefinition) - { - Debug.Assert(specializedType.IsGenericType && specializedType.GetGenericTypeDefinition() == genericMethodDefinition.DeclaringType, "generic member definition doesn't match type."); -#if NET - return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition); -#else -#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields - const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; -#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields - return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); -#endif - } - - private record struct DescriptorKey(MethodInfo Method, string? Name, string? Description, AIJsonSchemaCreateOptions SchemaOptions); - } - - /// - /// Removes characters from a .NET member name that shouldn't be used in an AI function name. - /// - /// The .NET member name that should be sanitized. - /// - /// Replaces non-alphanumeric characters in the identifier with the underscore character. - /// Primarily intended to remove characters produced by compiler-generated method name mangling. 
- /// - internal static string SanitizeMemberName(string memberName) - { - Verify.NotNull(memberName); - return InvalidNameCharsRegex().Replace(memberName, "_"); - } - - /// Regex that flags any character other than ASCII digits or letters or the underscore. -#if NET - [GeneratedRegex("[^0-9A-Za-z_]")] - private static partial Regex InvalidNameCharsRegex(); -#else - private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; - private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); -#endif - - /// Invokes the MethodInfo with the specified target object and arguments. - private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) - { -#if NET - return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); -#else - try - { - return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); - } - catch (TargetInvocationException e) when (e.InnerException is not null) - { - // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions - // is ignored, the original exception will be wrapped in a TargetInvocationException. - // Unwrap it and throw that original exception, maintaining its stack information. - System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); - throw; - } -#endif - } - - /// - /// Implements a simple write-only memory stream that uses pooled buffers. 
- /// - private sealed class PooledMemoryStream : Stream - { - private const int DefaultBufferSize = 4096; - private byte[] _buffer; - private int _position; - - public PooledMemoryStream(int initialCapacity = DefaultBufferSize) - { - _buffer = ArrayPool.Shared.Rent(initialCapacity); - _position = 0; - } - - public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); - public override bool CanWrite => true; - public override bool CanRead => false; - public override bool CanSeek => false; - public override long Length => _position; - public override long Position - { - get => _position; - set => throw new NotSupportedException(); - } - - public override void Write(byte[] buffer, int offset, int count) - { - EnsureNotDisposed(); - EnsureCapacity(_position + count); - - Buffer.BlockCopy(buffer, offset, _buffer, _position, count); - _position += count; - } - - public override void Flush() - { - } - - public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); - - protected override void Dispose(bool disposing) - { - if (_buffer is not null) - { - ArrayPool.Shared.Return(_buffer); - _buffer = null!; - } - - base.Dispose(disposing); - } - - private void EnsureCapacity(int requiredCapacity) - { - if (requiredCapacity <= _buffer.Length) - { - return; - } - - int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); - byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); - Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); - - ArrayPool.Shared.Return(_buffer); - _buffer = newBuffer; - } - - private void EnsureNotDisposed() - { - if (_buffer is null) - { - Throw(); - static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); - } - } - } -} diff --git 
a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs index 8a5abc42a6e0..58ea317804f9 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs @@ -33,7 +33,7 @@ internal ChatClientAIService(IChatClient chatClient) var metadata = this._chatClient.GetService(); Verify.NotNull(metadata); - this._internalAttributes[nameof(metadata.ModelId)] = metadata.ModelId; + this._internalAttributes[AIServiceExtensions.ModelIdKey] = metadata.DefaultModelId; this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName; this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri; } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index c36ca453bd04..e035a436a83a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -1,10 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -27,7 +29,7 @@ internal static Task GetResponseAsync( Kernel? 
kernel = null, CancellationToken cancellationToken = default) { - var chatOptions = executionSettings?.ToChatOptions(kernel); + var chatOptions = GetChatOptionsFromSettings(executionSettings, kernel); // Try to parse the text as a chat history if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) @@ -39,6 +41,25 @@ internal static Task GetResponseAsync( return chatClient.GetResponseAsync(prompt, chatOptions, cancellationToken); } + /// Get ChatClient streaming response for the prompt, settings and kernel. + /// Target chat client service. + /// The standardized prompt input. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . + /// Streaming list of different completion streaming string updates generated by the remote model + internal static IAsyncEnumerable GetStreamingResponseAsync( + this IChatClient chatClient, + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var chatOptions = GetChatOptionsFromSettings(executionSettings, kernel); + + return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken); + } + /// Creates an for the specified . /// The chat client to be represented as a chat completion service. /// An optional that can be used to resolve services to use in the instance. @@ -64,21 +85,31 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl { Verify.NotNull(client); - return client.GetService()?.ModelId; + return client.GetService()?.DefaultModelId; } /// /// Creates a new that supports for function invocation with a . /// /// Target chat client service. + /// Optional logger factory to use for logging. /// Function invoking chat client. 
[Experimental("SKEXP0001")] - public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client) + public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILoggerFactory? loggerFactory = null) { Verify.NotNull(client); return client is KernelFunctionInvokingChatClient kernelFunctionInvocationClient ? kernelFunctionInvocationClient - : new KernelFunctionInvokingChatClient(client); + : new KernelFunctionInvokingChatClient(client, loggerFactory); + } + + private static ChatOptions GetChatOptionsFromSettings(PromptExecutionSettings? executionSettings, Kernel? kernel) + { + ChatOptions chatOptions = executionSettings?.ToChatOptions(kernel) ?? new ChatOptions().AddKernel(kernel); + + // Passing by reference to be used by AutoFunctionInvocationFilters + chatOptions.AdditionalProperties![ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings; + return chatOptions; } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs index 3d190a48c5e4..1501cb71d988 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -7,7 +8,6 @@ namespace Microsoft.SemanticKernel.ChatCompletion; internal static class ChatMessageExtensions { /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. internal static ChatMessageContent ToChatMessageContent(this ChatMessage message, Microsoft.Extensions.AI.ChatResponse? 
response = null) { ChatMessageContent result = new() @@ -46,4 +46,15 @@ internal static ChatMessageContent ToChatMessageContent(this ChatMessage message return result; } + + /// Converts a list of to a . + internal static ChatHistory ToChatHistory(this IEnumerable chatMessages) + { + ChatHistory chatHistory = []; + foreach (var message in chatMessages) + { + chatHistory.Add(message.ToChatMessageContent()); + } + return chatHistory; + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index d8fab37e57bd..68540a1c32d8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -13,6 +13,11 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// internal static class ChatOptionsExtensions { + internal const string KernelKey = "AutoInvokingKernel"; + internal const string IsStreamingKey = "AutoInvokingIsStreaming"; + internal const string ChatMessageContentKey = "AutoInvokingChatCompletionContent"; + internal const string PromptExecutionSettingsKey = "AutoInvokingPromptExecutionSettings"; + /// Converts a to a . internal static PromptExecutionSettings? ToPromptExecutionSettings(this ChatOptions? options) { @@ -118,4 +123,23 @@ internal static class ChatOptionsExtensions return settings; } + + /// + /// To enable usage of AutoFunctionInvocationFilters with ChatClient's the kernel needs to be provided in the ChatOptions + /// + /// Chat options. + /// Kernel to be used for auto function invocation. + internal static ChatOptions AddKernel(this ChatOptions options, Kernel? 
kernel) + { + Verify.NotNull(options); + + // Only add the kernel if it is provided + if (kernel is not null) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties.TryAdd(KernelKey, kernel); + } + + return options; + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs deleted file mode 100644 index f9f43ee630ae..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs - -/// -/// Represents options that can be provided when creating an from a method. -/// -[ExcludeFromCodeCoverage] -internal sealed class AIFunctionFactoryOptions -{ - /// - /// Initializes a new instance of the class. - /// - public AIFunctionFactoryOptions() - { - } - - /// Gets or sets the used to marshal .NET values being passed to the underlying delegate. - /// - /// If no value has been specified, the instance will be used. - /// - public JsonSerializerOptions? SerializerOptions { get; set; } - - /// - /// Gets or sets the governing the generation of JSON schemas for the function. - /// - /// - /// If no value has been specified, the instance will be used. - /// - public AIJsonSchemaCreateOptions? JsonSchemaCreateOptions { get; set; } - - /// Gets or sets the name to use for the function. - /// - /// The name to use for the function. 
The default value is a name derived from the method represented by the passed or . - /// - public string? Name { get; set; } - - /// Gets or sets the description to use for the function. - /// - /// The description for the function. The default value is a description derived from the passed or , if possible - /// (for example, via a on the method). - /// - public string? Description { get; set; } - - /// - /// Gets or sets additional values to store on the resulting property. - /// - /// - /// This property can be used to provide arbitrary information about the function. - /// - public IReadOnlyDictionary? AdditionalProperties { get; set; } -} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs deleted file mode 100644 index da5af46620fd..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using Microsoft.Extensions.AI; - -#pragma warning disable IDE0009 // Use explicit 'this.' qualifier -#pragma warning disable CA2213 // Disposable fields should be disposed -#pragma warning disable IDE0044 // Add readonly modifier - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs - -/// Provides context for an in-flight function invocation. -[ExcludeFromCodeCoverage] -internal sealed class KernelFunctionInvocationContext -{ - /// - /// A nop function used to allow to be non-nullable. Default instances of - /// start with this as the target function. 
- /// - private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext)); - - /// The chat contents associated with the operation that initiated this function call request. - private IList _messages = Array.Empty(); - - /// The AI function to be invoked. - private AIFunction _function = s_nopFunction; - - /// The function call content information associated with this invocation. - private Microsoft.Extensions.AI.FunctionCallContent _callContent = new(string.Empty, s_nopFunction.Name, EmptyReadOnlyDictionary.Instance); - - /// Initializes a new instance of the class. - internal KernelFunctionInvocationContext() - { - } - - /// Gets or sets the function call content information associated with this invocation. - public Microsoft.Extensions.AI.FunctionCallContent CallContent - { - get => _callContent; - set - { - Verify.NotNull(value); - _callContent = value; - } - } - - /// Gets or sets the chat contents associated with the operation that initiated this function call request. - public IList Messages - { - get => _messages; - set - { - Verify.NotNull(value); - _messages = value; - } - } - - /// Gets or sets the chat options associated with the operation that initiated this function call request. - public ChatOptions? Options { get; set; } - - /// Gets or sets the AI function to be invoked. - public AIFunction Function - { - get => _function; - set - { - Verify.NotNull(value); - _function = value; - } - } - - /// Gets or sets the number of this iteration with the underlying client. - /// - /// The initial request to the client that passes along the chat contents provided to the - /// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on. - /// - public int Iteration { get; set; } - - /// Gets or sets the index of the function call within the iteration. 
- /// - /// The response from the underlying client may include multiple function call requests. - /// This index indicates the position of the function call within the iteration. - /// - public int FunctionCallIndex { get; set; } - - /// Gets or sets the total number of function call requests within the iteration. - /// - /// The response from the underlying client might include multiple function call requests. - /// This count indicates how many there were. - /// - public int FunctionCount { get; set; } - - /// Gets or sets a value indicating whether to terminate the request. - /// - /// In response to a function call request, the function might be invoked, its result added to the chat contents, - /// and a new request issued to the wrapped client. If this property is set to , that subsequent request - /// will not be issued and instead the loop immediately terminated rather than continuing until there are no - /// more function call requests in responses. - /// - public bool Terminate { get; set; } -} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 1a59b8f5ccbd..ea2dce48fc62 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -1,26 +1,31 @@ // Copyright (c) Microsoft. All rights reserved. 
using System; +#pragma warning restore IDE0073 // The file header does not match the required text using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +#pragma warning disable IDE1006 // Naming Styles +#pragma warning disable IDE0009 // This #pragma warning disable CA2213 // Disposable fields should be disposed -#pragma warning disable IDE0009 // Use explicit 'this.' qualifier -#pragma warning disable IDE1006 // Missing prefix: 's_' +#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test +#pragma warning disable SA1202 // 'protected' members should come before 'private' members -namespace Microsoft.SemanticKernel.ChatCompletion; +// Modified source from 2025-04-07 +// https://raw.githubusercontent.com/dotnet/extensions/84d09b794d994435568adcbb85a981143d4f15cb/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +namespace Microsoft.Extensions.AI; /// /// A delegating chat client that invokes functions defined on . @@ -28,9 +33,11 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// /// /// -/// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a . +/// When this client receives a in a chat response, it responds +/// by calling the corresponding defined in , +/// producing a that it sends back to the inner client. 
This loop +/// is repeated until there are no more function calls to make, or until another stop condition is met, +/// such as hitting . /// /// /// The provided implementation of is thread-safe for concurrent use so long as the @@ -44,11 +51,13 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// invocation requests to that same function. /// /// -[ExcludeFromCodeCoverage] -internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient +public partial class KernelFunctionInvokingChatClient : DelegatingChatClient { - /// The for the current function invocation. - private static readonly AsyncLocal _currentContext = new(); + /// The for the current function invocation. + private static readonly AsyncLocal _currentContext = new(); + + /// Optional services used for function invocation. + private readonly IServiceProvider? _functionInvocationServices; /// The logger to use for logging information about function invocation. private readonly ILogger _logger; @@ -58,49 +67,37 @@ internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatC private readonly ActivitySource? _activitySource; /// Maximum number of roundtrips allowed to the inner client. - private int? _maximumIterationsPerRequest; + private int _maximumIterationsPerRequest = 10; + + /// Maximum number of consecutive iterations that are allowed contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing. + private int _maximumConsecutiveErrorsPerRequest = 3; /// /// Initializes a new instance of the class. /// /// The underlying , or the next instance in a chain of clients. - /// An to use for logging information about function invocation. - public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null) + /// An to use for logging information about function invocation. + /// An optional to use for resolving services required by the instances being invoked. 
+ public KernelFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) : base(innerClient) { - _logger = logger ?? NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _activitySource = innerClient.GetService(); + _functionInvocationServices = functionInvocationServices; } /// - /// Gets or sets the for the current function invocation. + /// Gets or sets the for the current function invocation. /// /// /// This value flows across async calls. /// - internal static KernelFunctionInvocationContext? CurrentContext + public static AutoFunctionInvocationContext? CurrentContext { get => _currentContext.Value; - set => _currentContext.Value = value; + protected set => _currentContext.Value = value; } - /// - /// Gets or sets a value indicating whether to handle exceptions that occur during function calls. - /// - /// - /// if the - /// underlying will be instructed to give a response without invoking - /// any further functions if a function call fails with an exception. - /// if the underlying is allowed - /// to continue attempting function calls until is reached. - /// The default value is . - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether errors are retried during an in-flight request. - /// - public bool RetryOnError { get; set; } - /// /// Gets or sets a value indicating whether detailed exception information should be included /// in the chat history when calling the underlying . @@ -116,17 +113,17 @@ internal static KernelFunctionInvocationContext? CurrentContext /// Setting the value to prevents the underlying language model from disclosing /// raw exception details to the end user, since it doesn't receive that information. Even in this /// case, the raw object is available to application code by inspecting - /// the property. + /// the property. 
/// /// /// Setting the value to can help the underlying bypass problems on - /// its own, for example by retrying the function call with different arguments. However it might + /// its own, for example by retrying the function call with different arguments. However, it might /// result in disclosing the raw exception information to external users, which can be a security /// concern depending on the application scenario. /// /// /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether detailed errors are provided during an in-flight request. + /// whether detailed errors are provided during an in-flight request. /// /// public bool IncludeDetailedErrors { get; set; } @@ -151,23 +148,22 @@ internal static KernelFunctionInvocationContext? CurrentContext /// /// /// The maximum number of iterations per request. - /// The default value is . + /// The default value is 10. /// /// /// - /// Each request to this might end up making + /// Each request to this might end up making /// multiple requests to the inner client. Each time the inner client responds with /// a function call request, this client might perform that invocation and send the results /// back to the inner client in a new request. This property limits the number of times - /// such a roundtrip is performed. If null, there is no limit applied. If set, the value - /// must be at least one, as it includes the initial request. + /// such a roundtrip is performed. The value must be at least one, as it includes the initial request. /// /// /// Changing the value of this property while the client is in use might result in inconsistencies /// as to how many iterations are allowed for an in-flight request. /// /// - public int? MaximumIterationsPerRequest + public int MaximumIterationsPerRequest { get => _maximumIterationsPerRequest; set @@ -181,6 +177,47 @@ public int? 
MaximumIterationsPerRequest } } + /// + /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error. + /// + /// + /// The maximum number of consecutive iterations that are allowed to fail with an error. + /// The default value is 3. + /// + /// + /// + /// When function invocations fail with an exception, the + /// continues to make requests to the inner client, optionally supplying exception information (as + /// controlled by ). This allows the to + /// recover from errors by trying other function parameters that may succeed. + /// + /// + /// However, in case function invocations continue to produce exceptions, this property can be used to + /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be + /// rethrown to the caller. + /// + /// + /// If the value is set to zero, all function calling exceptions immediately terminate the function + /// invocation loop and the exception will be rethrown to the caller. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to how many iterations are allowed for an in-flight request. + /// + /// + public int MaximumConsecutiveErrorsPerRequest + { + get => _maximumConsecutiveErrorsPerRequest; + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Argument less than minimum value 0"); + } + _maximumConsecutiveErrorsPerRequest = value; + } + } + /// public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) @@ -200,8 +237,9 @@ public override async Task GetResponseAsync( ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response UsageDetails? 
totalUsage = null; // tracked usage across all turns, to be used for the final response - List? functionCallContents = null; // function call contents that need responding to in the current turn + List? functionCallContents = null; // function call contents that need responding to in the current turn bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set + int consecutiveErrorCount = 0; for (int iteration = 0; ; iteration++) { @@ -217,7 +255,7 @@ public override async Task GetResponseAsync( // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. bool requiresFunctionInvocation = options?.Tools is { Count: > 0 } && - (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) && + iteration < MaximumIterationsPerRequest && CopyFunctionCalls(response.Messages, ref functionCallContents); // In a common case where we make a request and there's no function calling work required, @@ -227,7 +265,7 @@ public override async Task GetResponseAsync( return response; } - // Track aggregatable details from the response, including all of the response messages and usage details. + // Track aggregate details from the response, including all the response messages and usage details. (responseMessages ??= []).AddRange(response.Messages); if (response.Usage is not null) { @@ -252,16 +290,24 @@ public override async Task GetResponseAsync( // Prepare the history for the next iteration. FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); + // Prepare the options for the next auto function invocation iteration. + UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: false); + // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. 
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) + if (modeAndMessages.ShouldTerminate) { - // Terminate break; } + + // Clear the auto function invocation options. + ClearOptionsForAutoFunctionInvocation(ref options); + + UpdateOptionsForNextIteration(ref options!, response.ChatThreadId); } Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); @@ -279,7 +325,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); + using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); // Copy the original messages in order to avoid enumerating the original messages multiple times. // The IEnumerable can represent an arbitrary amount of work. @@ -287,10 +333,11 @@ public override async IAsyncEnumerable GetStreamingResponseA messages = originalMessages; List? augmentedHistory = null; // the actual history of messages sent on turns other than the first - List? functionCallContents = null; // function call contents that need responding to in the current turn + List? functionCallContents = null; // function call contents that need responding to in the current turn List? 
responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set List updates = []; // updates from the current response + int consecutiveErrorCount = 0; for (int iteration = 0; ; iteration++) { @@ -315,7 +362,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // If there are no tools to call, or for any other reason we should stop, return the response. if (functionCallContents is not { Count: > 0 } || options?.Tools is not { Count: > 0 } || - (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + iteration >= _maximumIterationsPerRequest) { break; } @@ -327,13 +374,26 @@ public override async IAsyncEnumerable GetStreamingResponseA // Prepare the history for the next iteration. FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + // Prepare the options for the next auto function invocation iteration. + UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true); + + // Process all the functions, adding their results into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - // Stream any generated function results. 
This mirrors what's done for GetResponseAsync, where the returned messages - // includes all activities, including generated function results. + // Clear the auto function invocation options. + ClearOptionsForAutoFunctionInvocation(ref options); + + // This is a synthetic ID since we're generating the tool messages instead of getting them from + // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to + // use the same message ID for all of them within a given iteration, as this is a single logical + // message with multiple content items. We could also use different message IDs per tool content, + // but there's no benefit to doing so. string toolResponseId = Guid.NewGuid().ToString("N"); + + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages + // include all activity, including generated function results. foreach (var message in modeAndMessages.MessagesAdded) { var toolResultUpdate = new ChatResponseUpdate @@ -345,6 +405,7 @@ public override async IAsyncEnumerable GetStreamingResponseA Contents = message.Contents, RawRepresentation = message.RawRepresentation, ResponseId = toolResponseId, + MessageId = toolResponseId, // See above for why this can be the same as ResponseId Role = message.Role, }; @@ -352,11 +413,12 @@ public override async IAsyncEnumerable GetStreamingResponseA Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) + if (modeAndMessages.ShouldTerminate) { - // Terminate yield break; } + + UpdateOptionsForNextIteration(ref options, response.ChatThreadId); } } @@ -396,7 +458,7 @@ private static void FixupHistories( { // In the very rare case where the inner client returned a response with a thread ID but then // returned a subsequent response without one, we want to reconstitute the full history. 
To do that, - // we can populate the history with the original chat messages and then all of the response + // we can populate the history with the original chat messages and then all the response // messages up until this point, which includes the most recent ones. augmentedHistory ??= []; augmentedHistory.Clear(); @@ -424,7 +486,7 @@ private static void FixupHistories( /// Copies any from to . private static bool CopyFunctionCalls( - IList messages, [NotNullWhen(true)] ref List? functionCalls) + IList messages, [NotNullWhen(true)] ref List? functionCalls) { bool any = false; int count = messages.Count; @@ -436,15 +498,15 @@ private static bool CopyFunctionCalls( return any; } - /// Copies any from to . + /// Copies any from to . private static bool CopyFunctionCalls( - IList content, [NotNullWhen(true)] ref List? functionCalls) + IList content, [NotNullWhen(true)] ref List? functionCalls) { bool any = false; int count = content.Count; for (int i = 0; i < count; i++) { - if (content[i] is Microsoft.Extensions.AI.FunctionCallContent functionCall) + if (content[i] is FunctionCallContent functionCall) { (functionCalls ??= []).Add(functionCall); any = true; @@ -454,47 +516,54 @@ private static bool CopyFunctionCalls( return any; } - /// Updates for the response. - /// true if the function calling loop should terminate; otherwise, false. - private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId) + private static void UpdateOptionsForAutoFunctionInvocation(ref ChatOptions options, ChatMessageContent content, bool isStreaming) { - switch (mode) + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false) { - case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: - // We have to reset the tool mode to be non-required after the first iteration, - // as otherwise we'll be in an infinite loop. 
- options = options.Clone(); - options.ToolMode = null; - options.ChatThreadId = chatThreadId; - - break; + throw new KernelException($"The reserved key name '{ChatOptionsExtensions.IsStreamingKey}' is already specified in the options. Avoid using this key name."); + } - case ContinueMode.AllowOneMoreRoundtrip: - // The LLM gets one further chance to answer, but cannot use tools. - options = options.Clone(); - options.Tools = null; - options.ToolMode = null; - options.ChatThreadId = chatThreadId; + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false) + { + throw new KernelException($"The reserved key name '{ChatOptionsExtensions.ChatMessageContentKey}' is already specified in the options. Avoid using this key name."); + } - break; + options.AdditionalProperties ??= []; - case ContinueMode.Terminate: - // Bail immediately. - return true; + options.AdditionalProperties[ChatOptionsExtensions.IsStreamingKey] = isStreaming; + options.AdditionalProperties[ChatOptionsExtensions.ChatMessageContentKey] = content; + } - default: - // As with the other modes, ensure we've propagated the chat thread ID to the options. - // We only need to clone the options if we're actually mutating it. - if (options.ChatThreadId != chatThreadId) - { - options = options.Clone(); - options.ChatThreadId = chatThreadId; - } + private static void ClearOptionsForAutoFunctionInvocation(ref ChatOptions options) + { + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false) + { + options.AdditionalProperties.Remove(ChatOptionsExtensions.IsStreamingKey); + } - break; + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false) + { + options.AdditionalProperties.Remove(ChatOptionsExtensions.ChatMessageContentKey); } + } - return false; + private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? 
chatThreadId) + { + if (options.ToolMode is RequiredChatToolMode) + { + // We have to reset the tool mode to be non-required after the first iteration, + // as otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = null; + options.ChatThreadId = chatThreadId; + } + else if (options.ChatThreadId != chatThreadId) + { + // As with the other modes, ensure we've propagated the chat thread ID to the options. + // We only need to clone the options if we're actually mutating it. + options = options.Clone(); + options.ChatThreadId = chatThreadId; + } } /// @@ -504,71 +573,124 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// The options used for the response being processed. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. + /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. + /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. - private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) + /// A value indicating how the caller should proceed. + private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( + List messages, ChatOptions options, List functionCallContents, + int iteration, int consecutiveErrorCount, bool isStreaming, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. 
- Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call."); + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); + var shouldTerminate = false; + + var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. if (functionCallContents.Count == 1) { FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); + messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false); IList added = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(added); + UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); messages.AddRange(added); - return (result.ContinueMode, added); + return (result.ShouldTerminate, consecutiveErrorCount, added); } else { - FunctionInvocationResult[] results; + List results = []; + var terminationRequested = false; if (AllowConcurrentInvocation) { - // Schedule the invocation of every function. - results = await Task.WhenAll( + // Rather than awaiting each function before invoking the next, invoke all of them + // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, + // but if a function invocation completes asynchronously, its processing can overlap + // with the processing of other the other invocation invocations. 
+ results.AddRange(await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) - select Task.Run(() => ProcessFunctionCallAsync( + select ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, cancellationToken))).ConfigureAwait(false); + iteration, i, captureExceptions: true, isStreaming, cancellationToken)).ConfigureAwait(false)); + + terminationRequested = results.Any(r => r.ShouldTerminate); } else { // Invoke each function serially. - results = new FunctionInvocationResult[functionCallContents.Count]; - for (int i = 0; i < results.Length; i++) + for (int i = 0; i < functionCallContents.Count; i++) { - results[i] = await ProcessFunctionCallAsync( + var result = await ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, cancellationToken).ConfigureAwait(false); + iteration, i, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false); + + results.Add(result); + + if (result.ShouldTerminate) + { + shouldTerminate = true; + terminationRequested = true; + break; + } } } - ContinueMode continueMode = ContinueMode.Continue; - IList added = CreateResponseMessages(results); ThrowIfNoFunctionResultsAdded(added); + UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); messages.AddRange(added); - foreach (FunctionInvocationResult fir in results) + + if (!terminationRequested) { - if (fir.ContinueMode > continueMode) + // If any function requested termination, we'll terminate. + shouldTerminate = false; + foreach (FunctionInvocationResult fir in results) { - continueMode = fir.ContinueMode; + shouldTerminate = shouldTerminate || fir.ShouldTerminate; } } - return (continueMode, added); + return (shouldTerminate, consecutiveErrorCount, added); } } + private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) + { + var allExceptions = added.SelectMany(m => m.Contents.OfType()) + .Select(frc => frc.Exception!) 
+ .Where(e => e is not null); + +#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection + if (allExceptions.Any()) + { + consecutiveErrorCount++; + if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest) + { + var allExceptionsArray = allExceptions.ToArray(); + if (allExceptionsArray.Length == 1) + { + ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); + } + else + { + throw new AggregateException(allExceptionsArray); + } + } + } + else + { + consecutiveErrorCount = 0; + } +#pragma warning restore CA1851 // Possible multiple enumerations of 'IEnumerable' collection + } + /// /// Throws an exception if doesn't create any messages. /// @@ -576,7 +698,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) { if (messages is null || messages.Count == 0) { - throw new InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); + throw new InvalidOperationException($"{this.GetType().Name}.{nameof(this.CreateResponseMessages)} returned null or an empty collection of messages."); } } @@ -586,11 +708,13 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The 0-based index of the function being called out of . + /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows. + /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. + /// A value indicating how the caller should proceed. 
private async Task ProcessFunctionCallAsync( - List messages, ChatOptions options, List callContents, - int iteration, int functionCallIndex, CancellationToken cancellationToken) + List messages, ChatOptions options, List callContents, + int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; @@ -598,19 +722,28 @@ private async Task ProcessFunctionCallAsync( AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); if (function is null) { - return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); + return new(shouldTerminate: false, FunctionInvokingChatClient.FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); } - KernelFunctionInvocationContext context = new() + if (callContent.Arguments is not null) { + callContent.Arguments = new KernelArguments(callContent.Arguments); + } + + var context = new AutoFunctionInvocationContext(new() + { + Function = function, + Arguments = new(callContent.Arguments) { Services = _functionInvocationServices }, + Messages = messages, Options = options, + CallContent = callContent, - Function = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, FunctionCount = callContents.Count, - }; + }) + { IsStreaming = isStreaming }; object? result; try @@ -619,56 +752,46 @@ private async Task ProcessFunctionCallAsync( } catch (Exception e) when (!cancellationToken.IsCancellationRequested) { + if (!captureExceptions) + { + throw; + } + return new( - RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer. 
- FunctionInvocationStatus.Exception, + shouldTerminate: false, + FunctionInvokingChatClient.FunctionInvocationStatus.Exception, callContent, result: null, exception: e); } return new( - context.Terminate ? ContinueMode.Terminate : ContinueMode.Continue, - FunctionInvocationStatus.RanToCompletion, + shouldTerminate: context.Terminate, + FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion, callContent, result, exception: null); } - /// Represents the return value of , dictating how the loop should behave. - /// These values are ordered from least severe to most severe, and code explicitly depends on the ordering. - internal enum ContinueMode - { - /// Send back the responses and continue processing. - Continue = 0, - - /// Send back the response but without any tools. - AllowOneMoreRoundtrip = 1, - - /// Immediately exit the function calling loop. - Terminate = 2, - } - /// Creates one or more response messages for function invocation results. /// Information about the function call invocations and results. /// A list of all chat messages created from . - internal IList CreateResponseMessages( - ReadOnlySpan results) + private IList CreateResponseMessages(List results) { - var contents = new List(results.Length); - for (int i = 0; i < results.Length; i++) + var contents = new List(results.Count); + for (int i = 0; i < results.Count; i++) { contents.Add(CreateFunctionResultContent(results[i])); } return [new(ChatRole.Tool, contents)]; - Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) + FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) { Verify.NotNull(result); object? functionResult; - if (result.Status == FunctionInvocationStatus.RanToCompletion) + if (result.Status == FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion) { functionResult = result.Result ?? 
"Success: Function completed."; } @@ -676,8 +799,8 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi { string message = result.Status switch { - FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", - FunctionInvocationStatus.Exception => "Error: Function failed.", + FunctionInvokingChatClient.FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", + FunctionInvokingChatClient.FunctionInvocationStatus.Exception => "Error: Function failed.", _ => "Error: Unknown error.", }; @@ -689,7 +812,49 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi functionResult = message; } - return new Microsoft.Extensions.AI.FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception }; + return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception }; + } + } + + /// + /// Invokes the auto function invocation filters. + /// + /// The auto function invocation context. + /// The function to call after the filters. + /// The auto function invocation context. + private async Task OnAutoFunctionInvocationAsync( + AutoFunctionInvocationContext context, + Func functionCallCallback) + { + await this.InvokeFilterOrFunctionAsync(functionCallCallback, context).ConfigureAwait(false); + + return context; + } + + /// + /// This method will execute auto function invocation filters and function recursively. + /// If there are no registered filters, just function will be executed. + /// If there are registered filters, filter on position will be executed. + /// Second parameter of filter is callback. It can be either filter on + 1 position or function if there are no remaining filters to execute. + /// Function will always be executed as last step after all filters. 
+ /// + private async Task InvokeFilterOrFunctionAsync( + Func functionCallCallback, + AutoFunctionInvocationContext context, + int index = 0) + { + IList autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters; + + if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count) + { + await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( + context, + (ctx) => this.InvokeFilterOrFunctionAsync(functionCallCallback, ctx, index + 1) + ).ConfigureAwait(false); + } + else + { + await functionCallCallback(context).ConfigureAwait(false); } } @@ -700,7 +865,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . /// is . - internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext context, CancellationToken cancellationToken) + private async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken) { Verify.NotNull(context); @@ -712,7 +877,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi startingTimestamp = Stopwatch.GetTimestamp(); if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.Function.JsonSerializerOptions)); + LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.AIFunction.JsonSerializerOptions)); } else { @@ -723,8 +888,24 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi object? 
result = null; try { - CurrentContext = context; - result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit + context = await this.OnAutoFunctionInvocationAsync( + context, + async (ctx) => + { + // Check if filter requested termination + if (ctx.Terminate) + { + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. + result = await context.AIFunction.InvokeAsync(new(context.Arguments), cancellationToken).ConfigureAwait(false); + ctx.Result = new FunctionResult(ctx.Function, result); + }).ConfigureAwait(false); + result = context.Result.GetValue(); } catch (Exception e) { @@ -753,7 +934,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi if (result is not null && _logger.IsEnabled(LogLevel.Trace)) { - LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.Function.JsonSerializerOptions)); + LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.AIFunction.JsonSerializerOptions)); } else { @@ -815,9 +996,9 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) => /// Provides information about the invocation of a function call. public sealed class FunctionInvocationResult { - internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationStatus status, Microsoft.Extensions.AI.FunctionCallContent callContent, object? result, Exception? 
exception) + internal FunctionInvocationResult(bool shouldTerminate, FunctionInvokingChatClient.FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception) { - ContinueMode = continueMode; + ShouldTerminate = shouldTerminate; Status = status; CallContent = callContent; Result = result; @@ -825,10 +1006,10 @@ internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationS } /// Gets status about how the function invocation completed. - public FunctionInvocationStatus Status { get; } + public FunctionInvokingChatClient.FunctionInvocationStatus Status { get; } /// Gets the function call content information associated with this invocation. - public Microsoft.Extensions.AI.FunctionCallContent CallContent { get; } + public FunctionCallContent CallContent { get; } /// Gets the result of the function call. public object? Result { get; } @@ -836,20 +1017,7 @@ internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationS /// Gets any exception the function call threw. public Exception? Exception { get; } - /// Gets an indication for how the caller should continue the processing loop. - internal ContinueMode ContinueMode { get; } - } - - /// Provides error codes for when errors occur as part of the function calling loop. - public enum FunctionInvocationStatus - { - /// The operation completed successfully. - RanToCompletion, - - /// The requested function could not be found. - NotFound, - - /// The function call failed with an exception. - Exception, + /// Gets a value indicating whether the caller should terminate the processing loop. 
+ internal bool ShouldTerminate { get; } } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs index 22968c47ea38..147cdd5ba332 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs @@ -18,6 +18,14 @@ public class ChatHistory : IList, IReadOnlyListThe messages. private readonly List _messages; + private Action? _overrideAdd; + private Func? _overrideRemove; + private Action? _overrideClear; + private Action? _overrideInsert; + private Action? _overrideRemoveAt; + private Action? _overrideRemoveRange; + private Action>? _overrideAddRange; + /// Initializes an empty history. /// /// Creates a new instance of the class @@ -27,6 +35,38 @@ public ChatHistory() this._messages = []; } + // This allows observation of the chat history changes by-reference reflecting in an + // internal IEnumerable when used from IChatClients + // with AutoFunctionInvocationFilters + internal void SetOverrides( + Action overrideAdd, + Func overrideRemove, + Action onClear, + Action overrideInsert, + Action overrideRemoveAt, + Action overrideRemoveRange, + Action> overrideAddRange) + { + this._overrideAdd = overrideAdd; + this._overrideRemove = overrideRemove; + this._overrideClear = onClear; + this._overrideInsert = overrideInsert; + this._overrideRemoveAt = overrideRemoveAt; + this._overrideRemoveRange = overrideRemoveRange; + this._overrideAddRange = overrideAddRange; + } + + internal void ClearOverrides() + { + this._overrideAdd = null; + this._overrideRemove = null; + this._overrideClear = null; + this._overrideInsert = null; + this._overrideRemoveAt = null; + this._overrideRemoveRange = null; + this._overrideAddRange = null; + } + /// /// Creates a new instance of the with a first message in the provided . 
/// If not role is provided then the first message will default to role. @@ -37,8 +77,7 @@ public ChatHistory(string message, AuthorRole role) { Verify.NotNullOrWhiteSpace(message); - this._messages = []; - this.Add(new ChatMessageContent(role, message)); + this._messages = [new ChatMessageContent(role, message)]; } /// @@ -60,7 +99,7 @@ public ChatHistory(IEnumerable messages) } /// Gets the number of messages in the history. - public int Count => this._messages.Count; + public virtual int Count => this._messages.Count; /// /// Role of the message author @@ -118,29 +157,32 @@ public void AddDeveloperMessage(string content) => /// Adds a message to the history. /// The message to add. /// is null. - public void Add(ChatMessageContent item) + public virtual void Add(ChatMessageContent item) { Verify.NotNull(item); this._messages.Add(item); + this._overrideAdd?.Invoke(item); } /// Adds the messages to the history. /// The collection whose messages should be added to the history. /// is null. - public void AddRange(IEnumerable items) + public virtual void AddRange(IEnumerable items) { Verify.NotNull(items); this._messages.AddRange(items); + this._overrideAddRange?.Invoke(items); } /// Inserts a message into the history at the specified index. /// The index at which the item should be inserted. /// The message to insert. /// is null. - public void Insert(int index, ChatMessageContent item) + public virtual void Insert(int index, ChatMessageContent item) { Verify.NotNull(item); this._messages.Insert(index, item); + this._overrideInsert?.Invoke(index, item); } /// @@ -151,17 +193,22 @@ public void Insert(int index, ChatMessageContent item) /// is null. /// The number of messages in the history is greater than the available space from to the end of . /// is less than 0. 
- public void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex); + public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) + => this._messages.CopyTo(array, arrayIndex); /// Removes all messages from the history. - public void Clear() => this._messages.Clear(); + public virtual void Clear() + { + this._messages.Clear(); + this._overrideClear?.Invoke(); + } /// Gets or sets the message at the specified index in the history. /// The index of the message to get or set. /// The message at the specified index. /// is null. /// The was not valid for this history. - public ChatMessageContent this[int index] + public virtual ChatMessageContent this[int index] { get => this._messages[index]; set @@ -175,7 +222,7 @@ public ChatMessageContent this[int index] /// The message to locate. /// true if the message is found in the history; otherwise, false. /// is null. - public bool Contains(ChatMessageContent item) + public virtual bool Contains(ChatMessageContent item) { Verify.NotNull(item); return this._messages.Contains(item); @@ -185,7 +232,7 @@ public bool Contains(ChatMessageContent item) /// The message to locate. /// The index of the first found occurrence of the specified message; -1 if the message could not be found. /// is null. - public int IndexOf(ChatMessageContent item) + public virtual int IndexOf(ChatMessageContent item) { Verify.NotNull(item); return this._messages.IndexOf(item); @@ -194,16 +241,22 @@ public int IndexOf(ChatMessageContent item) /// Removes the message at the specified index from the history. /// The index of the message to remove. /// The was not valid for this history. - public void RemoveAt(int index) => this._messages.RemoveAt(index); + public virtual void RemoveAt(int index) + { + this._messages.RemoveAt(index); + this._overrideRemoveAt?.Invoke(index); + } /// Removes the first occurrence of the specified message from the history. /// The message to remove from the history. 
/// true if the item was successfully removed; false if it wasn't located in the history. /// is null. - public bool Remove(ChatMessageContent item) + public virtual bool Remove(ChatMessageContent item) { Verify.NotNull(item); - return this._messages.Remove(item); + var result = this._messages.Remove(item); + this._overrideRemove?.Invoke(item); + return result; } /// @@ -214,9 +267,10 @@ public bool Remove(ChatMessageContent item) /// is less than 0. /// is less than 0. /// and do not denote a valid range of messages. - public void RemoveRange(int index, int count) + public virtual void RemoveRange(int index, int count) { this._messages.RemoveRange(index, count); + this._overrideRemoveRange?.Invoke(index, count); } /// diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs index 5c7e56d0ce43..381e073a1446 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs @@ -80,6 +80,67 @@ public static async Task ReduceAsync(this ChatHistory chatHistory, return chatHistory; } + /// Converts a to a list. + /// The chat history to convert. + /// A list of objects. 
internal static List ToChatMessageList(this ChatHistory chatHistory) => chatHistory.Select(m => m.ToChatMessage()).ToList(); + + internal static void SetChatMessageHandlers(this ChatHistory chatHistory, IList messages) + { + chatHistory.SetOverrides(Add, Remove, Clear, Insert, RemoveAt, RemoveRange, AddRange); + + void Add(ChatMessageContent item) + { + messages.Add(item.ToChatMessage()); + } + + void Clear() + { + messages.Clear(); + } + + bool Remove(ChatMessageContent item) + { + var index = chatHistory.IndexOf(item); + + if (index < 0) + { + return false; + } + + messages.RemoveAt(index); + + return true; + } + + void Insert(int index, ChatMessageContent item) + { + messages.Insert(index, item.ToChatMessage()); + } + + void RemoveAt(int index) + { + messages.RemoveAt(index); + } + + void RemoveRange(int index, int count) + { + if (messages is List messageList) + { + messageList.RemoveRange(index, count); + return; + } + + foreach (var chatMessage in messages.Skip(index).Take(count)) + { + messages.Remove(chatMessage); + } + } + + void AddRange(IEnumerable items) + { + messages.AddRange(items.Select(i => i.ToChatMessage())); + } + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs index 98bb09be6f85..74fe27c2e841 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs @@ -8,13 +8,15 @@ using System.Text.Json; using System.Text.Json.Serialization.Metadata; using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; -internal static class PromptExecutionSettingsExtensions +/// Extensions methods for . +public static class PromptExecutionSettingsExtensions { /// Converts a pair of and to a . - internal static ChatOptions? ToChatOptions(this PromptExecutionSettings? 
settings, Kernel? kernel) + public static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? kernel) { if (settings is null) { @@ -149,7 +151,8 @@ internal static class PromptExecutionSettingsExtensions options.Tools = functions.Select(f => f.AsAIFunction(kernel)).Cast().ToList(); } - return options; + // Enables usage of AutoFunctionInvocationFilters + return options.AddKernel(kernel!); // Be a little lenient on the types of the values used in the extension data, // e.g. allow doubles even when requesting floats. diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index 0f18be8df8e0..bc8dd0c3490c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -1,6 +1,11 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; using System.Threading; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; @@ -10,6 +15,34 @@ namespace Microsoft.SemanticKernel; /// public class AutoFunctionInvocationContext { + private ChatHistory? _chatHistory; + private KernelFunction? _kernelFunction; + private readonly Microsoft.Extensions.AI.FunctionInvocationContext _invocationContext = new(); + + /// + /// Initializes a new instance of the class from an existing . + /// + internal AutoFunctionInvocationContext(Microsoft.Extensions.AI.FunctionInvocationContext invocationContext) + { + Verify.NotNull(invocationContext); + Verify.NotNull(invocationContext.Options); + + // the ChatOptions must be provided with AdditionalProperties. 
+ Verify.NotNull(invocationContext.Options.AdditionalProperties); + + invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); + Verify.NotNull(kernel); + + invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); + Verify.NotNull(chatMessageContent); + + invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); + this.ExecutionSettings = executionSettings; + this._invocationContext = invocationContext; + + this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture }; + } + /// /// Initializes a new instance of the class. /// @@ -31,11 +64,21 @@ public AutoFunctionInvocationContext( Verify.NotNull(chatHistory); Verify.NotNull(chatMessageContent); - this.Kernel = kernel; - this.Function = function; + this._invocationContext.Options = new() + { + AdditionalProperties = new() + { + [ChatOptionsExtensions.ChatMessageContentKey] = chatMessageContent, + [ChatOptionsExtensions.KernelKey] = kernel + } + }; + + this._kernelFunction = function; + this._chatHistory = chatHistory; + this._invocationContext.Messages = chatHistory.ToChatMessageList(); + chatHistory.SetChatMessageHandlers(this._invocationContext.Messages); + this._invocationContext.Function = function.AsAIFunction(); this.Result = result; - this.ChatHistory = chatHistory; - this.ChatMessageContent = chatMessageContent; } /// @@ -52,71 +95,263 @@ public AutoFunctionInvocationContext( /// /// Gets the arguments associated with the operation. /// - public KernelArguments? Arguments { get; init; } + public KernelArguments? Arguments + { + get => this._invocationContext.CallContent.Arguments is KernelArguments kernelArguments ? kernelArguments : null; + init => this._invocationContext.CallContent.Arguments = value; + } /// /// Request sequence index of automatic function invocation process. 
Starts from 0. /// - public int RequestSequenceIndex { get; init; } + public int RequestSequenceIndex + { + get => this._invocationContext.Iteration; + init => this._invocationContext.Iteration = value; + } /// /// Function sequence index. Starts from 0. /// - public int FunctionSequenceIndex { get; init; } + public int FunctionSequenceIndex + { + get => this._invocationContext.FunctionCallIndex; + init => this._invocationContext.FunctionCallIndex = value; + } - /// - /// Number of functions that will be invoked during auto function invocation request. - /// - public int FunctionCount { get; init; } + /// Gets or sets the total number of function call requests within the iteration. + /// + /// The response from the underlying client might include multiple function call requests. + /// This count indicates how many there were. + /// + public int FunctionCount + { + get => this._invocationContext.FunctionCount; + init => this._invocationContext.FunctionCount = value; + } /// /// The ID of the tool call. /// - public string? ToolCallId { get; init; } + public string? ToolCallId + { + get => this._invocationContext.CallContent.CallId; + init + { + this._invocationContext.CallContent = new Microsoft.Extensions.AI.FunctionCallContent( + callId: value ?? string.Empty, + name: this._invocationContext.CallContent.Name, + arguments: this._invocationContext.CallContent.Arguments); + } + } /// /// The chat message content associated with automatic function invocation. /// - public ChatMessageContent ChatMessageContent { get; } + public ChatMessageContent ChatMessageContent + => (this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!; /// /// The execution settings associated with the operation. /// - public PromptExecutionSettings? ExecutionSettings { get; init; } + public PromptExecutionSettings? 
ExecutionSettings + { + get => this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; + init + { + this._invocationContext.Options ??= new(); + this._invocationContext.Options.AdditionalProperties ??= []; + this._invocationContext.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; + } + } /// /// Gets the associated with automatic function invocation. /// - public ChatHistory ChatHistory { get; } + public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this._invocationContext.Messages); /// /// Gets the with which this filter is associated. /// - public KernelFunction Function { get; } + public KernelFunction Function + { + get + { + if (this._kernelFunction is null + // If the schemas are different, + // AIFunction reference potentially was modified and the kernel function should be regenerated. + || !IsSameSchema(this._kernelFunction, this._invocationContext.Function)) + { + this._kernelFunction = this._invocationContext.Function.AsKernelFunction(); + } + + return this._kernelFunction; + } + } /// /// Gets the containing services, plugins, and other state for use throughout the operation. /// - public Kernel Kernel { get; } + public Kernel Kernel + { + get + { + Kernel? kernel = null; + this._invocationContext.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel); + + // To avoid exception from properties, when attempting to retrieve a kernel from a non-ready context, it will give a null. + return kernel!; + } + } /// /// Gets or sets the result of the function's invocation. /// public FunctionResult Result { get; set; } + /// Gets or sets a value indicating whether to terminate the request. + /// + /// In response to a function call request, the function might be invoked, its result added to the chat contents, + /// and a new request issued to the wrapped client. 
If this property is set to , that subsequent request + /// will not be issued and instead the loop immediately terminated rather than continuing until there are no + /// more function call requests in responses. + /// + public bool Terminate + { + get => this._invocationContext.Terminate; + set => this._invocationContext.Terminate = value; + } + + /// Gets or sets the function call content information associated with this invocation. + internal Microsoft.Extensions.AI.FunctionCallContent CallContent + { + get => this._invocationContext.CallContent; + set => this._invocationContext.CallContent = value; + } + + internal AIFunction AIFunction + { + get => this._invocationContext.Function; + set => this._invocationContext.Function = value; + } + + private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction) + { + // Compares the schemas, should be similar. + return string.Equals( + kernelFunction.AsAIFunction().JsonSchema.ToString(), + aiFunction.JsonSchema.ToString(), + StringComparison.OrdinalIgnoreCase); + + // TODO: Later can be improved by comparing the underlying methods. + // return kernelFunction.UnderlyingMethod == aiFunction.UnderlyingMethod; + } + /// - /// Gets or sets a value indicating whether the operation associated with the filter should be terminated. - /// - /// By default, this value is , which means all functions will be invoked. - /// If set to , the behavior depends on how functions are invoked: - /// - /// - If functions are invoked sequentially (the default behavior), the remaining functions will not be invoked, - /// and the last request to the LLM will not be performed. - /// - /// - If functions are invoked concurrently (controlled by the option), - /// other functions will still be invoked, and the last request to the LLM will not be performed. - /// - /// In both cases, the automatic function invocation process will be terminated, and the result of the last executed function will be returned to the caller. 
+ /// Mutable IEnumerable of chat message as chat history. /// - public bool Terminate { get; set; } + private class ChatMessageHistory : ChatHistory, IEnumerable + { + private readonly List _messages; + + internal ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory()) + { + this._messages = new List(messages); + } + + public override void Add(ChatMessageContent item) + { + base.Add(item); + this._messages.Add(item.ToChatMessage()); + } + + public override void Clear() + { + base.Clear(); + this._messages.Clear(); + } + + public override bool Remove(ChatMessageContent item) + { + var index = base.IndexOf(item); + + if (index < 0) + { + return false; + } + + this._messages.RemoveAt(index); + base.RemoveAt(index); + + return true; + } + + public override void Insert(int index, ChatMessageContent item) + { + base.Insert(index, item); + this._messages.Insert(index, item.ToChatMessage()); + } + + public override void RemoveAt(int index) + { + this._messages.RemoveAt(index); + base.RemoveAt(index); + } + + public override ChatMessageContent this[int index] + { + get => this._messages[index].ToChatMessageContent(); + set + { + this._messages[index] = value.ToChatMessage(); + base[index] = value; + } + } + + public override void RemoveRange(int index, int count) + { + this._messages.RemoveRange(index, count); + base.RemoveRange(index, count); + } + + public override void CopyTo(ChatMessageContent[] array, int arrayIndex) + { + for (int i = 0; i < this._messages.Count; i++) + { + array[arrayIndex + i] = this._messages[i].ToChatMessageContent(); + } + } + + public override bool Contains(ChatMessageContent item) => base.Contains(item); + + public override int IndexOf(ChatMessageContent item) => base.IndexOf(item); + + public override void AddRange(IEnumerable items) + { + base.AddRange(items); + this._messages.AddRange(items.Select(i => i.ToChatMessage())); + } + + public override int Count => this._messages.Count; + + // Explicit implementation of 
IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var message in this._messages) + { + yield return message.ToChatMessageContent(); // Convert and yield each item + } + } + + // Explicit implementation of non-generic IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() + => ((IEnumerable)this).GetEnumerator(); + } + + /// Destructor to clear the chat history overrides. + ~AutoFunctionInvocationContext() + { + // The moment this class is destroyed, we need to clear the update message overrides + this._chatHistory?.ClearOverrides(); + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index c0a76d0e0c2c..0e09c1d525b7 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -542,7 +542,6 @@ public KernelAIFunction(KernelFunction kernelFunction, Kernel? 
kernel) this.JsonSchema = BuildFunctionSchema(kernelFunction); } - public override string Name { get; } public override JsonElement JsonSchema { get; } public override string Description => this._kernelFunction.Description; diff --git a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj index 47043cbe1df8..4826426418a6 100644 --- a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj +++ b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj @@ -29,7 +29,7 @@ - + diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 61b3324ef420..8037f513ee65 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -283,7 +283,7 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync GetChatClientResultAsync( }; } - var modelId = chatClient.GetService()?.ModelId; + var modelId = chatClient.GetService()?.DefaultModelId; // Usage details are global and duplicated for each chat message content, use first one to get usage information - this.CaptureUsageDetails(chatClient.GetService()?.ModelId, chatResponse.Usage, this._logger); + this.CaptureUsageDetails(chatClient.GetService()?.DefaultModelId, chatResponse.Usage, this._logger); return new FunctionResult(this, chatResponse) { diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs index 065784118c3d..20a2460e1751 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs @@ -14,8 +14,8 @@ public class 
AIFunctionKernelFunctionTests public void ShouldAssignIsRequiredParameterMetadataPropertyCorrectly() { // Arrange and Act - AIFunction aiFunction = AIFunctionFactory.Create((string p1, int? p2 = null) => p1, - new AIFunctionFactoryOptions { JsonSchemaCreateOptions = new AIJsonSchemaCreateOptions { RequireAllProperties = false } }); + AIFunction aiFunction = Microsoft.Extensions.AI.AIFunctionFactory.Create((string p1, int? p2 = null) => p1, + new Microsoft.Extensions.AI.AIFunctionFactoryOptions { JsonSchemaCreateOptions = new AIJsonSchemaCreateOptions { RequireAllProperties = false } }); AIFunctionKernelFunction sut = new(aiFunction); diff --git a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs index c9c348d1ac44..c25fa0ccd249 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs @@ -64,7 +64,6 @@ public void ItProvideStatusForResponsesWithoutContent() // Assert Assert.NotNull(httpOperationException); Assert.NotNull(httpOperationException.StatusCode); - Assert.Null(httpOperationException.ResponseContent); Assert.Equal(exception, httpOperationException.InnerException); Assert.Equal(exception.Message, httpOperationException.Message); Assert.Equal(pipelineResponse.Status, (int)httpOperationException.StatusCode!); diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs index 322c5f3b935f..7400875cdc9b 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs @@ -70,7 +70,7 @@ private sealed class ChatClientTest : IChatClient public ChatClientTest() { - this._metadata = 
new ChatClientMetadata(modelId: "Value1"); + this._metadata = new ChatClientMetadata(defaultModelId: "Value1"); } public void Dispose() diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs index b0a57fa8cdf6..52619c22af6e 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs @@ -527,7 +527,7 @@ private sealed class ChatClient : IChatClient public ChatClient(string modelId) { - this.Metadata = new ChatClientMetadata(modelId: modelId); + this.Metadata = new ChatClientMetadata(defaultModelId: modelId); } public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)