From fa1776f2297c63f9abb1b44510eb00f55229498a Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Tue, 1 Apr 2025 21:34:20 +0100 Subject: [PATCH 01/26] AddChatClient OpenAI WIP --- ...tClient_AutoFunctionInvocationFiltering.cs | 167 ++++++++++++++++++ .../Connectors.OpenAI.csproj | 1 + ...penAIKernelBuilderExtensions.ChatClient.cs | 106 +++++++++++ ...IServiceCollectionExtensions.ChatClient.cs | 157 ++++++++++++++++ .../AI/ChatClient/ChatClientExtensions.cs | 6 +- .../ClientResultExceptionExtensionsTests.cs | 1 - 6 files changed, 435 insertions(+), 3 deletions(-) create mode 100644 dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs create mode 100644 dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs create mode 100644 dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs diff --git a/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs new file mode 100644 index 000000000000..1e053618a385 --- /dev/null +++ b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.OpenAI; + +namespace Filtering; + +public class ChatClient_AutoFunctionInvocationFiltering(ITestOutputHelper output) : BaseTest(output) +{ + /// + /// Shows how to use . + /// + [Fact] + public async Task UsingAutoFunctionInvocationFilter() + { + var builder = Kernel.CreateBuilder(); + + builder.AddOpenAIChatClient("gpt-4", TestConfiguration.OpenAI.ApiKey); + + // This filter outputs information about auto function invocation and returns overridden result. 
+ builder.Services.AddSingleton(new AutoFunctionInvocationFilter(this.Output)); + + var kernel = builder.Build(); + + var function = KernelFunctionFactory.CreateFromMethod(() => "Result from function", "MyFunction"); + + kernel.ImportPluginFromFunctions("MyPlugin", [function]); + + var executionSettings = new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Required([function], autoInvoke: true) + }; + + var result = await kernel.InvokePromptAsync("Invoke provided function and return result", new(executionSettings)); + + Console.WriteLine(result); + + // Output: + // Request sequence number: 0 + // Function sequence number: 0 + // Total number of functions: 1 + // Result from auto function invocation filter. + } + + /// + /// Shows how to get list of function calls by using . + /// + [Fact] + public async Task GetFunctionCallsWithFilterAsync() + { + var builder = Kernel.CreateBuilder(); + + builder.AddOpenAIChatCompletion("gpt-3.5-turbo-1106", TestConfiguration.OpenAI.ApiKey); + + builder.Services.AddSingleton(new FunctionCallsFilter(this.Output)); + + var kernel = builder.Build(); + + kernel.ImportPluginFromFunctions("HelperFunctions", + [ + kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."), + kernel.CreateFunctionFromMethod((string cityName) => + cityName switch + { + "Boston" => "61 and rainy", + "London" => "55 and cloudy", + "Miami" => "80 and sunny", + "Paris" => "60 and rainy", + "Tokyo" => "50 and sunny", + "Sydney" => "75 and sunny", + "Tel Aviv" => "80 and sunny", + _ => "31 and snowing", + }, "GetWeatherForCity", "Gets the current weather for the specified city"), + ]); + + var executionSettings = new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + }; + + await foreach (var chunk in kernel.InvokePromptStreamingAsync("Check current UTC time and return current weather in Boston city.", 
new(executionSettings))) + { + Console.WriteLine(chunk.ToString()); + } + + // Output: + // Request #0. Function call: HelperFunctions.GetCurrentUtcTime. + // Request #0. Function call: HelperFunctions.GetWeatherForCity. + // The current UTC time is {time of execution}, and the current weather in Boston is 61°F and rainy. + } + + /// Shows available syntax for auto function invocation filter. + private sealed class AutoFunctionInvocationFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + // Example: get function information + var functionName = context.Function.Name; + + // Example: get chat history + var chatHistory = context.ChatHistory; + + // Example: get information about all functions which will be invoked + var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last()); + + // In function calling functionality there are two loops. + // Outer loop is "request" loop - it performs multiple requests to LLM until user ask will be satisfied. + // Inner loop is "function" loop - it handles LLM response with multiple function calls. + + // Workflow example: + // 1. Request to LLM #1 -> Response with 3 functions to call. + // 1.1. Function #1 called. + // 1.2. Function #2 called. + // 1.3. Function #3 called. + // 2. Request to LLM #2 -> Response with 2 functions to call. + // 2.1. Function #1 called. + // 2.2. Function #2 called. + + // context.RequestSequenceIndex - it's a sequence number of outer/request loop operation. + // context.FunctionSequenceIndex - it's a sequence number of inner/function loop operation. + // context.FunctionCount - number of functions which will be called per request (based on example above: 3 for first request, 2 for second request). 
+ + // Example: get request sequence index + output.WriteLine($"Request sequence index: {context.RequestSequenceIndex}"); + + // Example: get function sequence index + output.WriteLine($"Function sequence index: {context.FunctionSequenceIndex}"); + + // Example: get total number of functions which will be called + output.WriteLine($"Total number of functions: {context.FunctionCount}"); + + // Calling next filter in pipeline or function itself. + // By skipping this call, next filters and function won't be invoked, and function call loop will proceed to the next function. + await next(context); + + // Example: get function result + var result = context.Result; + + // Example: override function result value + context.Result = new FunctionResult(context.Result, "Result from auto function invocation filter"); + + // Example: Terminate function invocation + context.Terminate = true; + } + } + + /// Shows how to get list of all function calls per request. + private sealed class FunctionCallsFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + var chatHistory = context.ChatHistory; + var functionCalls = FunctionCallContent.GetFunctionCalls(chatHistory.Last()).ToArray(); + + if (functionCalls is { Length: > 0 }) + { + foreach (var functionCall in functionCalls) + { + output.WriteLine($"Request #{context.RequestSequenceIndex}. 
Function call: {functionCall.PluginName}.{functionCall.FunctionName}."); + } + } + + await next(context); + } + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj index 64a0e72bde6d..c17e878a7a42 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj @@ -37,6 +37,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs new file mode 100644 index 000000000000..e38aa5289104 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Net.Http; +using OpenAI; + +namespace Microsoft.SemanticKernel; + +/// +/// Sponsor extensions class for . +/// +public static class OpenAIChatClientKernelBuilderExtensions +{ + #region Chat Completion + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + string apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? 
httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + apiKey, + orgId, + serviceId, + httpClient); + + return builder; + } + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + OpenAIClient? openAIClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + openAIClient, + serviceId); + + return builder; + } + + /// + /// Adds the Custom Endpoint OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// Custom OpenAI Compatible Message API endpoint + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + [Experimental("SKEXP0010")] + public static IKernelBuilder AddOpenAIChatClient( + this IKernelBuilder builder, + string modelId, + Uri endpoint, + string? apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? 
httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddOpenAIChatClient( + modelId, + endpoint, + apiKey, + orgId, + serviceId, + httpClient); + + return builder; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs new file mode 100644 index 000000000000..ff214c441c2e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Http; +using OpenAI; + +namespace Microsoft.SemanticKernel; + +#pragma warning disable IDE0039 // Use local function + +/// +/// Sponsor extensions class for . +/// +public static class OpenAIChatClientServiceCollectionExtensions +{ + #region Chat Completion + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. + /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient( + this IServiceCollection services, + string modelId, + string apiKey, + string? orgId = null, + string? serviceId = null, + HttpClient? 
httpClient = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + + return new Microsoft.Extensions.AI.OpenAIChatClient( + openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))), + modelId: modelId) + .AsKernelFunctionInvokingChatClient(logger); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model id + /// to use for the service. If null, one must be available in the service provider when this service is resolved. + /// A local identifier for the given AI service + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient(this IServiceCollection services, + string modelId, + OpenAIClient? openAIClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + + return new Microsoft.Extensions.AI.OpenAIChatClient( + openAIClient ?? serviceProvider.GetRequiredService(), + modelId) + .AsKernelFunctionInvokingChatClient(logger); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Adds the Custom OpenAI chat completion service to the list. + /// + /// The instance to augment. + /// OpenAI model name, see https://platform.openai.com/docs/models + /// A Custom Message API compatible endpoint. + /// OpenAI API key, see https://platform.openai.com/account/api-keys + /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations. 
+ /// A local identifier for the given AI service + /// The HttpClient to use with this service. + /// The same instance as . + public static IServiceCollection AddOpenAIChatClient( + this IServiceCollection services, + string modelId, + Uri endpoint, + string? apiKey = null, + string? orgId = null, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(services); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + + return new Microsoft.Extensions.AI.OpenAIChatClient( + openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))), + modelId: modelId) + .AsKernelFunctionInvokingChatClient(logger); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + private static OpenAIClientOptions GetClientOptions( + Uri? endpoint = null, + string? orgId = null, + HttpClient? httpClient = null) + { + OpenAIClientOptions options = new(); + + if (endpoint is not null) + { + options.Endpoint = endpoint; + } + + if (orgId is not null) + { + options.OrganizationId = orgId; + } + + if (httpClient is not null) + { + options.Transport = new HttpClientPipelineTransport(httpClient); + } + + return options; + } + + /// + /// White space constant. 
+ /// + private const string SingleSpace = " "; + #endregion +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index c36ca453bd04..c480772750de 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -5,6 +5,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -71,14 +72,15 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl /// Creates a new that supports for function invocation with a . /// /// Target chat client service. + /// Optional logger to use for logging. /// Function invoking chat client. [Experimental("SKEXP0001")] - public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client) + public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILogger? logger = null) { Verify.NotNull(client); return client is KernelFunctionInvokingChatClient kernelFunctionInvocationClient ? 
kernelFunctionInvocationClient - : new KernelFunctionInvokingChatClient(client); + : new KernelFunctionInvokingChatClient(client, logger); } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs index c9c348d1ac44..c25fa0ccd249 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs @@ -64,7 +64,6 @@ public void ItProvideStatusForResponsesWithoutContent() // Assert Assert.NotNull(httpOperationException); Assert.NotNull(httpOperationException.StatusCode); - Assert.Null(httpOperationException.ResponseContent); Assert.Equal(exception, httpOperationException.InnerException); Assert.Equal(exception.Message, httpOperationException.Message); Assert.Equal(pipelineResponse.Status, (int)httpOperationException.StatusCode!); From 69ff2c2a57b6f76959b9212a3cff989958097daf Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 3 Apr 2025 13:23:46 +0100 Subject: [PATCH 02/26] Adding UT and Extension Methods --- ...FunctionInvocationFilterChatClientTests.cs | 776 ++++++++++++++++++ .../Core/AutoFunctionInvocationFilterTests.cs | 2 +- ...IKernelBuilderExtensionsChatClientTests.cs | 92 +++ ...viceCollectionExtensionsChatClientTests.cs | 114 +++ ...penAIKernelBuilderExtensions.ChatClient.cs | 2 +- ...IServiceCollectionExtensions.ChatClient.cs | 15 +- .../KernelFunctionInvocationContext.cs | 2 + .../KernelFunctionInvokingChatClient.cs | 16 +- 8 files changed, 1001 insertions(+), 18 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs create mode 100644 
dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs new file mode 100644 index 000000000000..9a8e91a25df5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -0,0 +1,776 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core; + +public sealed class AutoFunctionInvocationFilterChatClientTests : IDisposable +{ + private readonly MultipleHttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public AutoFunctionInvocationFilterChatClientTests() + { + this._messageHandlerStub = new MultipleHttpMessageHandlerStub(); + + this._httpClient = new HttpClient(this._messageHandlerStub, false); + } + + [Fact] + public async Task FiltersAreExecutedCorrectlyAsync() + { + // Arrange + int filterInvocations = 0; + int functionInvocations = 0; + int[] expectedRequestSequenceNumbers = [0, 0, 1, 1]; + int[] expectedFunctionSequenceNumbers = [0, 1, 0, 1]; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + Kernel? 
contextKernel = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + contextKernel = context.Kernel; + + if (context.ChatHistory.Last() is OpenAIChatMessageContent content) + { + Assert.Equal(2, content.ToolCalls.Count); + } + + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + filterInvocations++; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + })); + + // Assert + Assert.Equal(4, filterInvocations); + Assert.Equal(4, functionInvocations); + Assert.Equal(expectedRequestSequenceNumbers, requestSequenceNumbers); + Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers); + Assert.Same(kernel, contextKernel); + Assert.Equal("Test chat response", result.ToString()); + } + + [Fact] + public async Task FunctionSequenceIndexIsCorrectForConcurrentCallsAsync() + { + // Arrange + List functionSequenceNumbers = []; + List expectedFunctionSequenceNumbers = [0, 1, 0, 1]; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async 
(context, next) => + { + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new() + { + AllowParallelCalls = true, + AllowConcurrentInvocation = true + }) + })); + + // Assert + Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers); + } + + [Fact] + public async Task FiltersAreExecutedCorrectlyOnStreamingAsync() + { + // Arrange + int filterInvocations = 0; + int functionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + if (context.ChatHistory.Last() is OpenAIChatMessageContent content) + { + Assert.Equal(2, content.ToolCalls.Count); + } + + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + filterInvocations++; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { } + + // Assert + Assert.Equal(4, filterInvocations); + Assert.Equal(4, functionInvocations); + Assert.Equal([0, 0, 1, 1], 
requestSequenceNumbers); + Assert.Equal([0, 1, 0, 1], functionSequenceNumbers); + } + + [Fact] + public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync() + { + // Arrange + var executionOrder = new List(); + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter1 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter1-Invoking"); + await next(context); + }); + + var filter2 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter2-Invoking"); + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddSingleton((serviceProvider) => + { + return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + + // Case #1 - Add filter to services + builder.Services.AddSingleton(filter1); + + var kernel = builder.Build(); + + // Case #2 - Add filter to kernel + kernel.AutoFunctionInvocationFilters.Add(filter2); + + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + })); + + // Assert + Assert.Equal("Filter1-Invoking", executionOrder[0]); + Assert.Equal("Filter2-Invoking", executionOrder[1]); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming) + { + // Arrange + var executionOrder = new List(); + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = 
KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter1 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter1-Invoking"); + await next(context); + executionOrder.Add("Filter1-Invoked"); + }); + + var filter2 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter2-Invoking"); + await next(context); + executionOrder.Add("Filter2-Invoked"); + }); + + var filter3 = new AutoFunctionInvocationFilter(async (context, next) => + { + executionOrder.Add("Filter3-Invoking"); + await next(context); + executionOrder.Add("Filter3-Invoked"); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddSingleton((serviceProvider) => + { + return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + }); + + builder.Services.AddSingleton(filter1); + builder.Services.AddSingleton(filter2); + builder.Services.AddSingleton(filter3); + + var kernel = builder.Build(); + + var arguments = new KernelArguments(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + }); + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", arguments)) + { } + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", arguments); + } + + // Assert + Assert.Equal("Filter1-Invoking", executionOrder[0]); + Assert.Equal("Filter2-Invoking", executionOrder[1]); + Assert.Equal("Filter3-Invoking", executionOrder[2]); + Assert.Equal("Filter3-Invoked", executionOrder[3]); + Assert.Equal("Filter2-Invoked", executionOrder[4]); + 
Assert.Equal("Filter1-Invoked", executionOrder[5]); + } + + [Fact] + public async Task FilterCanOverrideArgumentsAsync() + { + // Arrange + const string NewValue = "NewValue"; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + context.Arguments!["parameter"] = NewValue; + await next(context); + context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() + })); + + // Assert + Assert.Equal("NewValue", result.ToString()); + } + + [Fact] + public async Task FilterCanHandleExceptionAsync() + { + // Arrange + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2"); + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + try + { + await next(context); + } + catch (KernelException exception) + { + Assert.Equal("Exception from Function1", exception.Message); + context.Result = new FunctionResult(context.Result, "Result from filter"); + } + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + + var executionSettings = new 
OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + + var chatHistory = new ChatHistory(); + chatHistory.AddSystemMessage("System message"); + + // Act + var result = await chatCompletion.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); + + var firstFunctionResult = chatHistory[^2].Content; + var secondFunctionResult = chatHistory[^1].Content; + + // Assert + Assert.Equal("Result from filter", firstFunctionResult); + Assert.Equal("Result from Function2", secondFunctionResult); + } + + [Fact] + public async Task FilterCanHandleExceptionOnStreamingAsync() + { + // Arrange + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2"); + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + try + { + await next(context); + } + catch (KernelException) + { + context.Result = new FunctionResult(context.Result, "Result from filter"); + } + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + + var chatHistory = new ChatHistory(); + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + await foreach (var item in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel)) + { } + + var firstFunctionResult = chatHistory[^2].Content; + var secondFunctionResult = chatHistory[^1].Content; + + // Assert + Assert.Equal("Result from filter", firstFunctionResult); + Assert.Equal("Result from Function2", 
secondFunctionResult); + } + + [Fact] + public async Task FiltersCanSkipFunctionExecutionAsync() + { + // Arrange + int filterInvocations = 0; + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Filter delegate is invoked only for second function, the first one should be skipped. + if (context.Function.Name == "Function2") + { + await next(context); + } + + filterInvocations++; + }); + + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_multiple_function_calls_test_response.json")) }; + using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) }; + + this._messageHandlerStub.ResponsesToReturn = [response1, response2]; + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + })); + + // Assert + Assert.Equal(2, filterInvocations); + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(1, secondFunctionInvocations); + } + + [Fact] + public async Task PreFilterCanTerminateOperationAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string 
parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Terminating before first function, so all functions won't be invoked. + context.Terminate = true; + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + })); + + // Assert + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + } + + [Fact] + public async Task PreFilterCanTerminateOperationOnStreamingAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + // Terminating before first function, so all functions won't be invoked. 
+ context.Terminate = true; + + await next(context); + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { } + + // Assert + Assert.Equal(0, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + } + + [Fact] + public async Task PostFilterCanTerminateOperationAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + // Terminating after first function, so second function won't be invoked. 
+ context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + })); + + // Assert + Assert.Equal(1, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + Assert.Equal([0], requestSequenceNumbers); + Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + var lastMessageContent = result.GetValue(); + Assert.NotNull(lastMessageContent); + + Assert.Equal("function1-value", lastMessageContent.Content); + Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); + } + + [Fact] + public async Task PostFilterCanTerminateOperationOnStreamingAsync() + { + // Arrange + int firstFunctionInvocations = 0; + int secondFunctionInvocations = 0; + List requestSequenceNumbers = []; + List functionSequenceNumbers = []; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var kernel = this.GetKernelWithFilter(plugin, async (context, next) => + { + requestSequenceNumbers.Add(context.RequestSequenceIndex); + functionSequenceNumbers.Add(context.FunctionSequenceIndex); + + await next(context); + + // Terminating after first function, so second function won't be invoked. 
+ context.Terminate = true; + }); + + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + + List streamingContent = []; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + { + streamingContent.Add(item); + } + + // Assert + Assert.Equal(1, firstFunctionInvocations); + Assert.Equal(0, secondFunctionInvocations); + Assert.Equal([0], requestSequenceNumbers); + Assert.Equal([0], functionSequenceNumbers); + + // Results of function invoked before termination should be returned + Assert.Equal(3, streamingContent.Count); + + var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; + Assert.NotNull(lastMessageContent); + + Assert.Equal("function1-value", lastMessageContent.Content); + Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) + { + // Arrange + bool? 
actualStreamingFlag = null; + + var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1"); + var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2"); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]); + + var filter = new AutoFunctionInvocationFilter(async (context, next) => + { + actualStreamingFlag = context.IsStreaming; + await next(context); + }); + + var builder = Kernel.CreateBuilder(); + + builder.Plugins.Add(plugin); + + builder.Services.AddSingleton((serviceProvider) => + { + return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + }); + + builder.Services.AddSingleton(filter); + + var kernel = builder.Build(); + + var arguments = new KernelArguments(new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + }); + + // Act + if (isStreaming) + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + await kernel.InvokePromptStreamingAsync("Test prompt", arguments).ToListAsync(); + } + else + { + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + await kernel.InvokePromptAsync("Test prompt", arguments); + } + + // Assert + Assert.Equal(isStreaming, actualStreamingFlag); + } + + [Fact] + public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync() + { + // Arrange + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]); + + AutoFunctionInvocationContext? 
actualContext = null; + + var kernel = this.GetKernelWithFilter(plugin, (context, next) => + { + actualContext = context; + return Task.CompletedTask; + }); + + var expectedExecutionSettings = new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + }; + + // Act + var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings)); + + // Assert + Assert.NotNull(actualContext); + Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings); + } + + [Fact] + public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync() + { + // Arrange + this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); + + var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]); + + AutoFunctionInvocationContext? actualContext = null; + + var kernel = this.GetKernelWithFilter(plugin, (context, next) => + { + actualContext = context; + return Task.CompletedTask; + }); + + var expectedExecutionSettings = new OpenAIPromptExecutionSettings + { + ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + }; + + // Act + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings))) + { } + + // Assert + Assert.NotNull(actualContext); + Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } + + #region private + +#pragma warning disable CA2000 // Dispose objects before losing scope + private static List GetFunctionCallingResponses() + { + return [ + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new 
StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) } + ]; + } + + private static List GetFunctionCallingStreamingResponses() + { + return [ + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) } + ]; + } +#pragma warning restore CA2000 + + private Kernel GetKernelWithFilter( + KernelPlugin plugin, + Func, Task>? onAutoFunctionInvocation) + { + var builder = Kernel.CreateBuilder(); + var filter = new AutoFunctionInvocationFilter(onAutoFunctionInvocation); + + builder.Plugins.Add(plugin); + builder.Services.AddSingleton(filter); + + builder.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); + + return builder.Build(); + } + + private sealed class AutoFunctionInvocationFilter( + Func, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter + { + private readonly Func, Task>? _onAutoFunctionInvocation = onAutoFunctionInvocation; + + public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) => + this._onAutoFunctionInvocation?.Invoke(context, next) ?? 
Task.CompletedTask; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs index 19992be01667..1f23d84c4dfe 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs @@ -312,7 +312,7 @@ public async Task FilterCanOverrideArgumentsAsync() // Act var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs new file mode 100644 index 000000000000..437d347aa194 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + +public class OpenAIKernelBuilderExtensionsChatClientTests +{ + [Fact] + public void AddOpenAIChatClientNullArgsThrow() + { + // Arrange + IKernelBuilder builder = null!; + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act & Assert + var exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId)); + Assert.Equal("builder", exception.ParamName); + + exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId)); + Assert.Equal("builder", exception.ParamName); + + using var httpClient = new HttpClient(); + exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient)); + Assert.Equal("builder", exception.ParamName); + } + + [Fact] + public void AddOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + var openAIClient = new OpenAIClient("test_api_key"); + string serviceId = "test_service_id"; + + // Act + builder.AddOpenAIChatClient(modelId, openAIClient, serviceId); + + // Assert + var kernel = builder.Build(); 
+ Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } + + [Fact] + public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService() + { + // Arrange + var builder = Kernel.CreateBuilder(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + using var httpClient = new HttpClient(); + + // Act + builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient); + + // Assert + var kernel = builder.Build(); + Assert.NotNull(kernel.GetRequiredService()); + Assert.NotNull(kernel.GetRequiredService(serviceId)); + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs new file mode 100644 index 000000000000..7a3888b95f30 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Net.Http; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using OpenAI; +using Xunit; + +namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions; + +public class OpenAIServiceCollectionExtensionsChatClientTests +{ + [Fact] + public void AddOpenAIChatClientNullArgsThrow() + { + // Arrange + ServiceCollection services = null!; + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act & Assert + var exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId)); + Assert.Equal("services", exception.ParamName); + + exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId)); + Assert.Equal("services", exception.ParamName); + + using var httpClient = new HttpClient(); + exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient)); + Assert.Equal("services", exception.ParamName); + } + + [Fact] + public void AddOpenAIChatClientDefaultValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + var openAIClient = new OpenAIClient("test_api_key"); + string serviceId = "test_service_id"; + + // Act + 
services.AddOpenAIChatClient(modelId, openAIClient, serviceId); + + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService() + { + // Arrange + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + using var httpClient = new HttpClient(); + // Act + services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient); + // Assert + var serviceProvider = services.BuildServiceProvider(); + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void AddOpenAIChatClientWorksWithKernel() + { + var services = new ServiceCollection(); + string modelId = "gpt-3.5-turbo"; + string apiKey = "test_api_key"; + string orgId = "test_org_id"; + string serviceId = "test_service_id"; + + // Act + services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId); + services.AddKernel(); + + var serviceProvider = services.BuildServiceProvider(); + var kernel = serviceProvider.GetRequiredService(); + + var serviceFromCollection = serviceProvider.GetKeyedService(serviceId); + var serviceFromKernel = kernel.GetRequiredService(serviceId); + + Assert.NotNull(serviceFromKernel); + Assert.Same(serviceFromCollection, serviceFromKernel); + } +} diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs index e38aa5289104..ac0005f5d5c5 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs +++ 
b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs @@ -10,6 +10,7 @@ namespace Microsoft.SemanticKernel; /// /// Sponsor extensions class for . /// +[Experimental("SKEXP0010")] public static class OpenAIChatClientKernelBuilderExtensions { #region Chat Completion @@ -79,7 +80,6 @@ public static IKernelBuilder AddOpenAIChatClient( /// A local identifier for the given AI service /// The HttpClient to use with this service. /// The same instance as . - [Experimental("SKEXP0010")] public static IKernelBuilder AddOpenAIChatClient( this IKernelBuilder builder, string modelId, diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs index ff214c441c2e..2ae915869698 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs @@ -3,6 +3,7 @@ using System; using System.ClientModel; using System.ClientModel.Primitives; +using System.Diagnostics.CodeAnalysis; using System.Net.Http; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; @@ -13,14 +14,16 @@ namespace Microsoft.SemanticKernel; -#pragma warning disable IDE0039 // Use local function - /// /// Sponsor extensions class for . /// +[Experimental("SKEXP0010")] public static class OpenAIChatClientServiceCollectionExtensions { - #region Chat Completion + /// + /// White space constant. + /// + private const string SingleSpace = " "; /// /// Adds the OpenAI chat completion service to the list. @@ -148,10 +151,4 @@ private static OpenAIClientOptions GetClientOptions( return options; } - - /// - /// White space constant. 
- /// - private const string SingleSpace = " "; - #endregion } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs index da5af46620fd..d782a7262e68 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs @@ -18,6 +18,8 @@ namespace Microsoft.SemanticKernel.ChatCompletion; [ExcludeFromCodeCoverage] internal sealed class KernelFunctionInvocationContext { + internal const string KernelKey = "Kernel"; + /// /// A nop function used to allow to be non-nullable. Default instances of /// start with this as the target function. diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 1a59b8f5ccbd..117e1ebb4f81 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -186,6 +186,7 @@ public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(messages); + options?.AdditionalProperties?.TryGetValue(KernelFunctionInvocationContext.KernelKey, out var kernel); // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. @@ -254,7 +255,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. 
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, kernel, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) @@ -504,15 +505,16 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// The options used for the response being processed. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. + /// The containing the auto function invocations. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) + List messages, ChatOptions options, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. - Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call."); + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. if (functionCallContents.Count == 1) @@ -694,13 +696,12 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi } /// Invokes the function asynchronously. 
- /// - /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. - /// + /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// The containing the auto function invocations. /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . /// is . - internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext context, CancellationToken cancellationToken) + internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext context, Kernel kernel, CancellationToken cancellationToken) { Verify.NotNull(context); @@ -724,6 +725,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi try { CurrentContext = context; + context = this.OnAutoFunctionInvocationAsync(kernel, context, ) result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false); } catch (Exception e) From dbd0aab5bda99251b44da15cee02f351aa3d3100 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 3 Apr 2025 14:25:01 +0100 Subject: [PATCH 03/26] Function Call impl --- .../AI/ChatClient/ChatMessageExtensions.cs | 13 ++- .../KernelFunctionInvocationContext.cs | 3 + ...rnelFunctionInvocationContextExtensions.cs | 38 +++++++ .../KernelFunctionInvokingChatClient.cs | 107 +++++++++++++++--- 4 files changed, 142 insertions(+), 19 deletions(-) create mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs index 3d190a48c5e4..1501cb71d988 100644 --- 
a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatMessageExtensions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; using Microsoft.Extensions.AI; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -7,7 +8,6 @@ namespace Microsoft.SemanticKernel.ChatCompletion; internal static class ChatMessageExtensions { /// Converts a to a . - /// This conversion should not be necessary once SK eventually adopts the shared content types. internal static ChatMessageContent ToChatMessageContent(this ChatMessage message, Microsoft.Extensions.AI.ChatResponse? response = null) { ChatMessageContent result = new() @@ -46,4 +46,15 @@ internal static ChatMessageContent ToChatMessageContent(this ChatMessage message return result; } + + /// Converts a list of to a . + internal static ChatHistory ToChatHistory(this IEnumerable chatMessages) + { + ChatHistory chatHistory = []; + foreach (var message in chatMessages) + { + chatHistory.Add(message.ToChatMessageContent()); + } + return chatHistory; + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs index d782a7262e68..344e021da7ad 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs @@ -40,6 +40,9 @@ internal KernelFunctionInvocationContext() { } + /// Chat response information + public ChatResponse? Response { get; set; } + /// Gets or sets the function call content information associated with this invocation. 
public Microsoft.Extensions.AI.FunctionCallContent CallContent { diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs new file mode 100644 index 000000000000..61449fcfa55f --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.AI; + +namespace Microsoft.SemanticKernel.ChatCompletion; + +// Slight modified source from +// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs + +/// Provides context for an in-flight function invocation. +internal static class KernelFunctionInvocationContextExtensions +{ + internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context, Microsoft.Extensions.AI.FunctionCallContent functionCall, Kernel kernel, ChatMessage message, bool isStreaming) + { + if (context is null) + { + throw new ArgumentNullException(nameof(context)); + } + + return new AutoFunctionInvocationContext( + kernel: kernel, + function: context.Function.AsKernelFunction(), + result: null, + chatHistory: context.Messages.ToChatHistory(), + chatMessageContent: message.ToChatMessageContent()) + { + Arguments = new(functionCall.Arguments), + FunctionName = functionCall.Name, + FunctionCount = context.FunctionCount, + FunctionSequenceIndex = context.FunctionCallIndex, + RequestSequenceIndex = context.Iteration, + IsStreaming = isStreaming, + ToolCallId = functionCall.CallId, + ExecutionSettings = context.Options.ToPromptExecutionSettings() + }; + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs 
b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 117e1ebb4f81..9096f3f310e8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -12,6 +12,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using static System.Runtime.InteropServices.JavaScript.JSType; #pragma warning disable CA2213 // Disposable fields should be disposed #pragma warning disable IDE0009 // Use explicit 'this.' qualifier @@ -255,7 +256,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, kernel, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, response, functionCallContents!, iteration, kernel, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) @@ -503,13 +504,14 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. + /// The response from the inner client. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The containing the auto function invocations. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken) + List messages, ChatOptions options, ChatResponse response, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. @@ -520,7 +522,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti if (functionCallContents.Count == 1) { FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); + messages, options, response, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); IList added = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(added); @@ -585,13 +587,14 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) /// Processes the function call described in []. /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. + /// The response from the inner client. /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The 0-based index of the function being called out of . /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
private async Task ProcessFunctionCallAsync( - List messages, ChatOptions options, List callContents, + List messages, ChatOptions options, ChatResponse response, List callContents, int iteration, int functionCallIndex, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; @@ -695,17 +698,60 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi } } + /// + /// Invokes the auto function invocation filters. + /// + /// The . + /// The auto function invocation context. + /// The function to call after the filters. + /// The auto function invocation context. + private async Task OnAutoFunctionInvocationAsync( + Kernel kernel, + AutoFunctionInvocationContext context, + Func functionCallCallback) + { + await this.InvokeFilterOrFunctionAsync(kernel.AutoFunctionInvocationFilters, functionCallCallback, context).ConfigureAwait(false); + + return context; + } + + /// + /// This method will execute auto function invocation filters and function recursively. + /// If there are no registered filters, just function will be executed. + /// If there are registered filters, filter on position will be executed. + /// Second parameter of filter is callback. It can be either filter on + 1 position or function if there are no remaining filters to execute. + /// Function will be always executed as last step after all filters. + /// + private async Task InvokeFilterOrFunctionAsync( + IList? 
autoFunctionInvocationFilters, + Func functionCallCallback, + KernelFunctionInvocationContext context, + int index = 0) + { + if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count) + { + await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( + context.ToAutoFunctionInvocationContext(), + (context) => this.InvokeFilterOrFunctionAsync(autoFunctionInvocationFilters, functionCallCallback, context, index + 1) + ).ConfigureAwait(false); + } + else + { + await functionCallCallback(context).ConfigureAwait(false); + } + } + /// Invokes the function asynchronously. - /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. /// The containing the auto function invocations. /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . - /// is . - internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext context, Kernel kernel, CancellationToken cancellationToken) + /// is . + internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, Kernel kernel, CancellationToken cancellationToken) { - Verify.NotNull(context); + Verify.NotNull(invocationContext); - using Activity? activity = _activitySource?.StartActivity(context.Function.Name); + using Activity? 
activity = _activitySource?.StartActivity(invocationContext.Function.Name); long startingTimestamp = 0; if (_logger.IsEnabled(LogLevel.Debug)) @@ -713,20 +759,45 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi startingTimestamp = Stopwatch.GetTimestamp(); if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.Function.JsonSerializerOptions)); + LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.Function.JsonSerializerOptions)); } else { - LogInvoking(context.Function.Name); + LogInvoking(invocationContext.Function.Name); } } object? result = null; try { - CurrentContext = context; - context = this.OnAutoFunctionInvocationAsync(kernel, context, ) - result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + CurrentContext = invocationContext; + var autoFunctionInvocationContext = invocationContext.ToAutoFunctionInvocationContext(kernel, ) + invocationContext = this.OnAutoFunctionInvocationAsync( + kernel, + invocationContext, + async (context) => + { + // Check if filter requested termination + if (context.Terminate) + { + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. + KernelArguments? 
arguments = null; + if (context.CallContent.Arguments is not null) + { + arguments = new(context.CallContent.Arguments); + } + + var result = await invocationContext.Function.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + context.Result = new FunctionResult(context.Function.AsKernelFunction(), result); + }).ConfigureAwait(false); + + ) + result = } catch (Exception e) { @@ -738,11 +809,11 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi if (e is OperationCanceledException) { - LogInvocationCanceled(context.Function.Name); + LogInvocationCanceled(invocationContext.Function.Name); } else { - LogInvocationFailed(context.Function.Name, e); + LogInvocationFailed(invocationContext.Function.Name, e); } throw; @@ -755,11 +826,11 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi if (result is not null && _logger.IsEnabled(LogLevel.Trace)) { - LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.Function.JsonSerializerOptions)); + LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.Function.JsonSerializerOptions)); } else { - LogInvocationCompleted(context.Function.Name, elapsed); + LogInvocationCompleted(invocationContext.Function.Name, elapsed); } } } From 9c892e1efb7af30c093ed3deb846e0ae2e409bbf Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 4 Apr 2025 18:42:56 +0100 Subject: [PATCH 04/26] Auto vs KernelFIC --- .../KernelFunctionInvocationContext.cs | 5 --- ...rnelFunctionInvocationContextExtensions.cs | 20 +++++----- .../KernelFunctionInvokingChatClient.cs | 38 ++++++++++--------- .../AutoFunctionInvocationContext.cs | 9 ++++- 4 files changed, 38 insertions(+), 34 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs 
b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs index 344e021da7ad..da5af46620fd 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs @@ -18,8 +18,6 @@ namespace Microsoft.SemanticKernel.ChatCompletion; [ExcludeFromCodeCoverage] internal sealed class KernelFunctionInvocationContext { - internal const string KernelKey = "Kernel"; - /// /// A nop function used to allow to be non-nullable. Default instances of /// start with this as the target function. @@ -40,9 +38,6 @@ internal KernelFunctionInvocationContext() { } - /// Chat response information - public ChatResponse? Response { get; set; } - /// Gets or sets the function call content information associated with this invocation. public Microsoft.Extensions.AI.FunctionCallContent CallContent { diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs index 61449fcfa55f..475698c89fd5 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using Microsoft.Extensions.AI; +using System.Linq; namespace Microsoft.SemanticKernel.ChatCompletion; @@ -11,27 +11,27 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// Provides context for an in-flight function invocation. 
internal static class KernelFunctionInvocationContextExtensions { - internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context, Microsoft.Extensions.AI.FunctionCallContent functionCall, Kernel kernel, ChatMessage message, bool isStreaming) + internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context) { if (context is null) { throw new ArgumentNullException(nameof(context)); } + var kernelFunction = context.Function.AsKernelFunction(); return new AutoFunctionInvocationContext( - kernel: kernel, - function: context.Function.AsKernelFunction(), - result: null, + kernel: context.Kernel, + function: kernelFunction, + result: context.Result, chatHistory: context.Messages.ToChatHistory(), - chatMessageContent: message.ToChatMessageContent()) + chatMessageContent: context.Response!.Messages.Last().ToChatMessageContent()) { - Arguments = new(functionCall.Arguments), - FunctionName = functionCall.Name, + Arguments = context.CallContent.Arguments is null ? null : new(context.CallContent.Arguments), FunctionCount = context.FunctionCount, FunctionSequenceIndex = context.FunctionCallIndex, RequestSequenceIndex = context.Iteration, - IsStreaming = isStreaming, - ToolCallId = functionCall.CallId, + IsStreaming = context.IsStreaming, + ToolCallId = context.CallContent.CallId, ExecutionSettings = context.Options.ToPromptExecutionSettings() }; } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 9096f3f310e8..99c689d61369 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -585,6 +585,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? 
messages) } /// Processes the function call described in []. + /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. /// The response from the inner client. @@ -594,6 +595,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task ProcessFunctionCallAsync( + Kernel kernel, List messages, ChatOptions options, ChatResponse response, List callContents, int iteration, int functionCallIndex, CancellationToken cancellationToken) { @@ -608,7 +610,9 @@ private async Task ProcessFunctionCallAsync( KernelFunctionInvocationContext context = new() { + Kernel = kernel, Messages = messages, + Response = response, Options = options, CallContent = callContent, Function = function, @@ -701,18 +705,17 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi /// /// Invokes the auto function invocation filters. /// - /// The . - /// The auto function invocation context. + /// The auto function invocation context. /// The function to call after the filters. /// The auto function invocation context. - private async Task OnAutoFunctionInvocationAsync( - Kernel kernel, - AutoFunctionInvocationContext context, + private async Task OnAutoFunctionInvocationAsync( + KernelFunctionInvocationContext kernelContext, Func functionCallCallback) { - await this.InvokeFilterOrFunctionAsync(kernel.AutoFunctionInvocationFilters, functionCallCallback, context).ConfigureAwait(false); + var autoContext = invocationContext.ToAutoFunctionInvocationContext(); + await this.InvokeFilterOrFunctionAsync(functionCallCallback, kernelContext).ConfigureAwait(false); - return context; + return kernelContext; } /// @@ -723,16 +726,17 @@ private async Task OnAutoFunctionInvocationAsyn /// Function will be always executed as last step after all filters. 
/// private async Task InvokeFilterOrFunctionAsync( - IList? autoFunctionInvocationFilters, - Func functionCallCallback, - KernelFunctionInvocationContext context, + Func functionCallCallback, + AutoFunctionInvocationContext context, int index = 0) { + IList? autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters; + if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count) { await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( - context.ToAutoFunctionInvocationContext(), - (context) => this.InvokeFilterOrFunctionAsync(autoFunctionInvocationFilters, functionCallCallback, context, index + 1) + context, + (context) => this.InvokeFilterOrFunctionAsync(functionCallCallback, context, index + 1) ).ConfigureAwait(false); } else @@ -747,7 +751,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . /// is . - internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, Kernel kernel, CancellationToken cancellationToken) + internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, CancellationToken cancellationToken) { Verify.NotNull(invocationContext); @@ -771,9 +775,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( try { CurrentContext = invocationContext; - var autoFunctionInvocationContext = invocationContext.ToAutoFunctionInvocationContext(kernel, ) invocationContext = this.OnAutoFunctionInvocationAsync( - kernel, invocationContext, async (context) => { @@ -787,13 +789,13 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( // further calls made as part of this function invocation. 
In particular, we must not use function calling settings naively here, // as the called function could in turn telling the model about itself as a possible candidate for invocation. KernelArguments? arguments = null; - if (context.CallContent.Arguments is not null) + if (context.Arguments is not null) { - arguments = new(context.CallContent.Arguments); + arguments = new(context.Arguments); } var result = await invocationContext.Function.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false); - context.Result = new FunctionResult(context.Function.AsKernelFunction(), result); + context.Result = new FunctionResult(context.Function, result); }).ConfigureAwait(false); ) diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index d13d5519b652..d61923eda663 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -9,8 +9,15 @@ namespace Microsoft.SemanticKernel; /// /// Class with data related to automatic function invocation. /// -public class AutoFunctionInvocationContext +public class AutoFunctionInvocationContext : KernelFunctionInvocationContext { + private KernelFunctionInvocationContext _innerContext; + + public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContext) + { + this._innerContext = innerContext; + } + /// /// Initializes a new instance of the class. 
/// From be5583e0f4a814734483d5aa1cd77924f8f74412 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 4 Apr 2025 20:33:50 +0100 Subject: [PATCH 05/26] AutoFunctionInvocationContext as KFIC --- .../KernelFunctionInvocationContext.cs | 6 +- ...rnelFunctionInvocationContextExtensions.cs | 38 ------ .../KernelFunctionInvokingChatClient.cs | 89 ++++++------ .../AI/ChatCompletion/ChatHistory.cs | 27 ++-- .../AutoFunctionInvocationContext.cs | 129 +++++++++++++++--- 5 files changed, 162 insertions(+), 127 deletions(-) delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs index da5af46620fd..a36f03630fd5 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs @@ -16,10 +16,10 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// Provides context for an in-flight function invocation. [ExcludeFromCodeCoverage] -internal sealed class KernelFunctionInvocationContext +public class KernelFunctionInvocationContext { /// - /// A nop function used to allow to be non-nullable. Default instances of + /// A nop function used to allow to be non-nullable. Default instances of /// start with this as the target function. /// private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext)); @@ -64,7 +64,7 @@ public IList Messages public ChatOptions? Options { get; set; } /// Gets or sets the AI function to be invoked. 
- public AIFunction Function + public AIFunction AIFunction { get => _function; set diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs deleted file mode 100644 index 475698c89fd5..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Linq; - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs - -/// Provides context for an in-flight function invocation. -internal static class KernelFunctionInvocationContextExtensions -{ - internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context) - { - if (context is null) - { - throw new ArgumentNullException(nameof(context)); - } - - var kernelFunction = context.Function.AsKernelFunction(); - return new AutoFunctionInvocationContext( - kernel: context.Kernel, - function: kernelFunction, - result: context.Result, - chatHistory: context.Messages.ToChatHistory(), - chatMessageContent: context.Response!.Messages.Last().ToChatMessageContent()) - { - Arguments = context.CallContent.Arguments is null ? 
null : new(context.CallContent.Arguments), - FunctionCount = context.FunctionCount, - FunctionSequenceIndex = context.FunctionCallIndex, - RequestSequenceIndex = context.Iteration, - IsStreaming = context.IsStreaming, - ToolCallId = context.CallContent.CallId, - ExecutionSettings = context.Options.ToPromptExecutionSettings() - }; - } -} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 99c689d61369..a17e9d02730a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -12,7 +12,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using static System.Runtime.InteropServices.JavaScript.JSType; #pragma warning disable CA2213 // Disposable fields should be disposed #pragma warning disable IDE0009 // Use explicit 'this.' qualifier @@ -187,7 +186,6 @@ public override async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { Verify.NotNull(messages); - options?.AdditionalProperties?.TryGetValue(KernelFunctionInvocationContext.KernelKey, out var kernel); // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. @@ -256,7 +254,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. 
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, response, functionCallContents!, iteration, kernel, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) @@ -504,14 +502,12 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. - /// The response from the inner client. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. - /// The containing the auto function invocations. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, ChatResponse response, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken) + List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. 
@@ -522,7 +518,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti if (functionCallContents.Count == 1) { FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, response, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); + messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); IList added = CreateResponseMessages([result]); ThrowIfNoFunctionResultsAdded(added); @@ -585,18 +581,15 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) } /// Processes the function call described in []. - /// The current chat contents, inclusive of the function call contents being processed. /// The options used for the response being processed. - /// The response from the inner client. /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The 0-based index of the function being called out of . /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. private async Task ProcessFunctionCallAsync( - Kernel kernel, - List messages, ChatOptions options, ChatResponse response, List callContents, + List messages, ChatOptions options, List callContents, int iteration, int functionCallIndex, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; @@ -610,12 +603,10 @@ private async Task ProcessFunctionCallAsync( KernelFunctionInvocationContext context = new() { - Kernel = kernel, Messages = messages, - Response = response, Options = options, CallContent = callContent, - Function = function, + AIFunction = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, FunctionCount = callContents.Count, @@ -705,17 +696,16 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi /// /// Invokes the auto function invocation filters. 
/// - /// The auto function invocation context. + /// The auto function invocation context. /// The function to call after the filters. /// The auto function invocation context. private async Task OnAutoFunctionInvocationAsync( - KernelFunctionInvocationContext kernelContext, + AutoFunctionInvocationContext context, Func functionCallCallback) { - var autoContext = invocationContext.ToAutoFunctionInvocationContext(); - await this.InvokeFilterOrFunctionAsync(functionCallCallback, kernelContext).ConfigureAwait(false); + await this.InvokeFilterOrFunctionAsync(functionCallCallback, context).ConfigureAwait(false); - return kernelContext; + return context; } /// @@ -747,7 +737,6 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( /// Invokes the function asynchronously. /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. - /// The containing the auto function invocations. /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . /// is . @@ -755,7 +744,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( { Verify.NotNull(invocationContext); - using Activity? activity = _activitySource?.StartActivity(invocationContext.Function.Name); + using Activity? 
activity = _activitySource?.StartActivity(invocationContext.AIFunction.Name); long startingTimestamp = 0; if (_logger.IsEnabled(LogLevel.Debug)) @@ -763,11 +752,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( startingTimestamp = Stopwatch.GetTimestamp(); if (_logger.IsEnabled(LogLevel.Trace)) { - LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.Function.JsonSerializerOptions)); + LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions)); } else { - LogInvoking(invocationContext.Function.Name); + LogInvoking(invocationContext.AIFunction.Name); } } @@ -775,31 +764,31 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( try { CurrentContext = invocationContext; - invocationContext = this.OnAutoFunctionInvocationAsync( - invocationContext, - async (context) => + if (invocationContext is AutoFunctionInvocationContext autoFunctionInvocationContext) { - // Check if filter requested termination - if (context.Terminate) - { - return; - } - - // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any - // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, - // as the called function could in turn telling the model about itself as a possible candidate for invocation. - KernelArguments? 
arguments = null; - if (context.Arguments is not null) + autoFunctionInvocationContext = await this.OnAutoFunctionInvocationAsync( + autoFunctionInvocationContext, + async (context) => { - arguments = new(context.Arguments); - } - - var result = await invocationContext.Function.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false); - context.Result = new FunctionResult(context.Function, result); - }).ConfigureAwait(false); - - ) - result = + // Check if filter requested termination + if (context.Terminate) + { + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. + KernelArguments? arguments = null; + if (context.Arguments is not null) + { + arguments = new(context.Arguments); + } + + var result = await invocationContext.AIFunction.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false); + context.Result = new FunctionResult(context.Function, result); + }).ConfigureAwait(false); + } } catch (Exception e) { @@ -811,11 +800,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (e is OperationCanceledException) { - LogInvocationCanceled(invocationContext.Function.Name); + LogInvocationCanceled(invocationContext.AIFunction.Name); } else { - LogInvocationFailed(invocationContext.Function.Name, e); + LogInvocationFailed(invocationContext.AIFunction.Name, e); } throw; @@ -828,11 +817,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (result is not null && _logger.IsEnabled(LogLevel.Trace)) { - LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, 
invocationContext.Function.JsonSerializerOptions)); + LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions)); } else { - LogInvocationCompleted(invocationContext.Function.Name, elapsed); + LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed); } } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs index 22968c47ea38..cfdad840349c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs @@ -37,8 +37,7 @@ public ChatHistory(string message, AuthorRole role) { Verify.NotNullOrWhiteSpace(message); - this._messages = []; - this.Add(new ChatMessageContent(role, message)); + this._messages = [new ChatMessageContent(role, message)]; } /// @@ -60,7 +59,7 @@ public ChatHistory(IEnumerable messages) } /// Gets the number of messages in the history. - public int Count => this._messages.Count; + public virtual int Count => this._messages.Count; /// /// Role of the message author @@ -118,7 +117,7 @@ public void AddDeveloperMessage(string content) => /// Adds a message to the history. /// The message to add. /// is null. - public void Add(ChatMessageContent item) + public virtual void Add(ChatMessageContent item) { Verify.NotNull(item); this._messages.Add(item); @@ -127,7 +126,7 @@ public void Add(ChatMessageContent item) /// Adds the messages to the history. /// The collection whose messages should be added to the history. /// is null. - public void AddRange(IEnumerable items) + public virtual void AddRange(IEnumerable items) { Verify.NotNull(items); this._messages.AddRange(items); @@ -137,7 +136,7 @@ public void AddRange(IEnumerable items) /// The index at which the item should be inserted. /// The message to insert. /// is null. 
- public void Insert(int index, ChatMessageContent item) + public virtual void Insert(int index, ChatMessageContent item) { Verify.NotNull(item); this._messages.Insert(index, item); @@ -151,17 +150,17 @@ public void Insert(int index, ChatMessageContent item) /// is null. /// The number of messages in the history is greater than the available space from to the end of . /// is less than 0. - public void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex); + public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex); /// Removes all messages from the history. - public void Clear() => this._messages.Clear(); + public virtual void Clear() => this._messages.Clear(); /// Gets or sets the message at the specified index in the history. /// The index of the message to get or set. /// The message at the specified index. /// is null. /// The was not valid for this history. - public ChatMessageContent this[int index] + public virtual ChatMessageContent this[int index] { get => this._messages[index]; set @@ -175,7 +174,7 @@ public ChatMessageContent this[int index] /// The message to locate. /// true if the message is found in the history; otherwise, false. /// is null. - public bool Contains(ChatMessageContent item) + public virtual bool Contains(ChatMessageContent item) { Verify.NotNull(item); return this._messages.Contains(item); @@ -185,7 +184,7 @@ public bool Contains(ChatMessageContent item) /// The message to locate. /// The index of the first found occurrence of the specified message; -1 if the message could not be found. /// is null. - public int IndexOf(ChatMessageContent item) + public virtual int IndexOf(ChatMessageContent item) { Verify.NotNull(item); return this._messages.IndexOf(item); @@ -194,13 +193,13 @@ public int IndexOf(ChatMessageContent item) /// Removes the message at the specified index from the history. /// The index of the message to remove. 
/// The was not valid for this history. - public void RemoveAt(int index) => this._messages.RemoveAt(index); + public virtual void RemoveAt(int index) => this._messages.RemoveAt(index); /// Removes the first occurrence of the specified message from the history. /// The message to remove from the history. /// true if the item was successfully removed; false if it wasn't located in the history. /// is null. - public bool Remove(ChatMessageContent item) + public virtual bool Remove(ChatMessageContent item) { Verify.NotNull(item); return this._messages.Remove(item); @@ -214,7 +213,7 @@ public bool Remove(ChatMessageContent item) /// is less than 0. /// is less than 0. /// and do not denote a valid range of messages. - public void RemoveRange(int index, int count) + public virtual void RemoveRange(int index, int count) { this._messages.RemoveRange(index, count); } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index d61923eda663..68e06140cc9b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -1,7 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; +using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; @@ -11,11 +14,28 @@ namespace Microsoft.SemanticKernel; /// public class AutoFunctionInvocationContext : KernelFunctionInvocationContext { - private KernelFunctionInvocationContext _innerContext; - + private readonly KernelFunction? _kernelFunction; + private readonly KernelFunctionInvocationContext? 
_innerContext; + /// + /// Initializes a new instance of the class from an existing . + /// public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContext) { + Verify.NotNull(innerContext); + Verify.NotNull(innerContext.Options); + Verify.NotNull(innerContext.Options.AdditionalProperties); + + innerContext.Options.AdditionalProperties.TryGetValue("Kernel", out var kernel); + Verify.NotNull(kernel); + + innerContext.Options.AdditionalProperties.TryGetValue("ChatMessageContent", out var chatMessageContent); + Verify.NotNull(chatMessageContent); + this._innerContext = innerContext; + this.ChatHistory = new ChatMessageHistory(innerContext.Messages); + this.ChatMessageContent = chatMessageContent; + this.Kernel = kernel; + this.Result = new FunctionResult(this._kernelFunction!) { Culture = kernel.Culture }; } /// @@ -40,7 +60,7 @@ public AutoFunctionInvocationContext( Verify.NotNull(chatMessageContent); this.Kernel = kernel; - this.Function = function; + this._kernelFunction = function; this.Result = result; this.ChatHistory = chatHistory; this.ChatMessageContent = chatMessageContent; @@ -72,11 +92,6 @@ public AutoFunctionInvocationContext( /// public int FunctionSequenceIndex { get; init; } - /// - /// Number of functions that will be invoked during auto function invocation request. - /// - public int FunctionCount { get; init; } - /// /// The ID of the tool call. /// @@ -101,7 +116,10 @@ public AutoFunctionInvocationContext( /// /// Gets the with which this filter is associated. /// - public KernelFunction Function { get; } + public KernelFunction Function + { + get => this._innerContext?.AIFunction.AsKernelFunction() ?? this._kernelFunction!; + } /// /// Gets the containing services, plugins, and other state for use throughout the operation. 
@@ -114,18 +132,85 @@ public AutoFunctionInvocationContext( public FunctionResult Result { get; set; } /// - /// Gets or sets a value indicating whether the operation associated with the filter should be terminated. - /// - /// By default, this value is , which means all functions will be invoked. - /// If set to , the behavior depends on how functions are invoked: - /// - /// - If functions are invoked sequentially (the default behavior), the remaining functions will not be invoked, - /// and the last request to the LLM will not be performed. - /// - /// - If functions are invoked concurrently (controlled by the option), - /// other functions will still be invoked, and the last request to the LLM will not be performed. - /// - /// In both cases, the automatic function invocation process will be terminated, and the result of the last executed function will be returned to the caller. + /// Mutable chat message as chat history. /// - public bool Terminate { get; set; } + internal class ChatMessageHistory : ChatHistory, IEnumerable + { + private readonly List _messages; + + public ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory()) + { + this._messages = new List(messages); + } + + public override void Add(ChatMessageContent item) + { + item.GetHashCode(); + base.Add(item); + this._messages.Add(item.ToChatMessage()); + } + + public override void Clear() + { + base.Clear(); + this._messages.Clear(); + } + + public override bool Contains(ChatMessageContent item) + { + return base.Contains(item); + } + + public override bool Remove(ChatMessageContent item) + { + var index = base.IndexOf(item); + + if (index < 0) + { + return false; + } + + this._messages.RemoveAt(index); + base.RemoveAt(index); + + return true; + } + + public override void Insert(int index, ChatMessageContent item) + { + base.Insert(index, item); + this._messages.Insert(index, item.ToChatMessage()); + } + + public override void RemoveAt(int index) + { + 
this._messages.RemoveAt(index); + base.RemoveAt(index); + } + + public override ChatMessageContent this[int index] + { + get => this._messages[index].ToChatMessageContent(); + set + { + this._messages[index] = value.ToChatMessage(); + base[index] = value; + } + } + + public override int Count => this._messages.Count; + + // Explicit implementation of IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() + { + foreach (var message in this._messages) + { + yield return message.ToChatMessageContent(); // Convert and yield each item + } + } + + // Explicit implementation of non-generic IEnumerable.GetEnumerator() + IEnumerator IEnumerable.GetEnumerator() + => ((IEnumerable)this).GetEnumerator(); + } } From 8663ab08d46fe5dbd517c79e76e0733d55e19fd6 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 7 Apr 2025 16:34:03 +0100 Subject: [PATCH 06/26] AutoFunctionInvocation WIP --- .../Connectors.OpenAI.UnitTests.csproj | 6 + ...FunctionInvocationFilterChatClientTests.cs | 10 +- ..._multiple_function_calls_test_response.txt | 9 ++ ...multiple_function_calls_test_response.json | 40 +++++ .../AI/ChatClient/ChatOptionsExtensions.cs | 21 +++ .../KernelFunctionInvokingChatClient.cs | 143 ++++++++++++------ .../AI/PromptExecutionSettingsExtensions.cs | 4 +- .../AutoFunctionInvocationContext.cs | 9 +- .../Functions/KernelFunction.cs | 2 +- 9 files changed, 185 insertions(+), 59 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj index 57f983dd5c9b..33557430e615 100644 --- 
a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj @@ -96,6 +96,12 @@ Always + + Always + + + Always + diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index 9a8e91a25df5..d28c67346c3f 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -596,7 +596,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() Assert.Equal([0], requestSequenceNumbers); Assert.Equal([0], functionSequenceNumbers); - // Results of function invoked before termination should be returned + // Results of function invoked before termination should be returned Assert.Equal(3, streamingContent.Count); var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; @@ -732,8 +732,8 @@ public void Dispose() private static List GetFunctionCallingResponses() { return [ - new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) }, - new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }, new HttpResponseMessage(HttpStatusCode.OK) { Content = new 
StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) } ]; } @@ -741,8 +741,8 @@ private static List GetFunctionCallingResponses() private static List GetFunctionCallingStreamingResponses() { return [ - new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) }, - new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) }, + new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) }, new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) } ]; } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt new file mode 100644 index 000000000000..17ce94647fd5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt @@ -0,0 +1,9 @@ +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_GetCurrentWeather","arguments":"{\n\"location\": \"Boston, MA\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: 
{"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_FunctionWithException","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":2,"id":"3","type":"function","function":{"name":"MyPlugin_NonExistentFunction","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":3,"id":"4","type":"function","function":{"name":"MyPlugin_InvalidArguments","arguments":"invalid_arguments_format"}}]},"finish_reason":"tool_calls"}]} + +data: [DONE] diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json new file mode 100644 index 000000000000..2c499b14089f --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json @@ -0,0 +1,40 @@ +{ + "id": "response-id", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-3.5-turbo-0613", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "1", + "type": "function", + "function": { + "name": "MyPlugin_Function1", + "arguments": 
"{\n\"parameter\": \"function1-value\"\n}" + } + }, + { + "id": "2", + "type": "function", + "function": { + "name": "MyPlugin_Function2", + "arguments": "{\n\"parameter\": \"function2-value\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index d8fab37e57bd..d6b5109f9b17 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json; using Microsoft.Extensions.AI; @@ -13,6 +14,10 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// internal static class ChatOptionsExtensions { + internal const string KernelKey = "AutoInvokingKernel"; + internal const string IsStreamingKey = "AutoInvokingIsStreaming"; + internal const string ChatMessageContentKey = "AutoInvokingChatCompletionContent"; + /// Converts a to a . internal static PromptExecutionSettings? ToPromptExecutionSettings(this ChatOptions? options) { @@ -118,4 +123,20 @@ internal static class ChatOptionsExtensions return settings; } + + /// + /// To enable usage of AutoFunctionInvocationFilters with ChatClient's the kernel needs to be provided in the ChatOptions + /// + /// Chat options. + /// Kernel to be used for auto function invocation. + internal static ChatOptions? 
AddKernel(this ChatOptions options, Kernel kernel) + { + Verify.NotNull(options); + Verify.NotNull(kernel); + + options.AdditionalProperties ??= []; + options.AdditionalProperties?.TryAdd(KernelKey, kernel); + + return options; + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index a17e9d02730a..360d3d53f184 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -48,7 +48,7 @@ namespace Microsoft.SemanticKernel.ChatCompletion; internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient { /// The for the current function invocation. - private static readonly AsyncLocal _currentContext = new(); + private static readonly AsyncLocal s_currentContext = new(); /// The logger to use for logging information about function invocation. private readonly ILogger _logger; @@ -68,8 +68,8 @@ internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatC public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null) : base(innerClient) { - _logger = logger ?? NullLogger.Instance; - _activitySource = innerClient.GetService(); + this._logger = logger ?? NullLogger.Instance; + this._activitySource = innerClient.GetService(); } /// @@ -80,8 +80,8 @@ public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger /// internal static KernelFunctionInvocationContext? CurrentContext { - get => _currentContext.Value; - set => _currentContext.Value = value; + get => s_currentContext.Value; + set => s_currentContext.Value = value; } /// @@ -97,7 +97,7 @@ internal static KernelFunctionInvocationContext? 
CurrentContext /// /// /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether errors are retried during an in-flight request. + /// whether errors are retried during an in-flight request. /// public bool RetryOnError { get; set; } @@ -169,7 +169,7 @@ internal static KernelFunctionInvocationContext? CurrentContext /// public int? MaximumIterationsPerRequest { - get => _maximumIterationsPerRequest; + get => this._maximumIterationsPerRequest; set { if (value < 1) @@ -177,7 +177,7 @@ public int? MaximumIterationsPerRequest throw new ArgumentOutOfRangeException(nameof(value)); } - _maximumIterationsPerRequest = value; + this._maximumIterationsPerRequest = value; } } @@ -189,7 +189,7 @@ public override async Task GetResponseAsync( // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); + using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); // Copy the original messages in order to avoid enumerating the original messages multiple times. // The IEnumerable can represent an arbitrary amount of work. @@ -217,7 +217,7 @@ public override async Task GetResponseAsync( // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. 
bool requiresFunctionInvocation = options?.Tools is { Count: > 0 } && - (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) && + (!this.MaximumIterationsPerRequest.HasValue || iteration < this.MaximumIterationsPerRequest.GetValueOrDefault()) && CopyFunctionCalls(response.Messages, ref functionCallContents); // In a common case where we make a request and there's no function calling work required, @@ -252,11 +252,17 @@ public override async Task GetResponseAsync( // Prepare the history for the next iteration. FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); + // Prepare the options for the next auto function invocation iteration. + UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: false); + // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + // Clear the auto function invocation options. + ClearOptionsForAutoFunctionInvocation(ref options!); + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) { // Terminate @@ -279,7 +285,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); + using Activity? 
activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); // Copy the original messages in order to avoid enumerating the original messages multiple times. // The IEnumerable can represent an arbitrary amount of work. @@ -315,7 +321,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // If there are no tools to call, or for any other reason we should stop, return the response. if (functionCallContents is not { Count: > 0 } || options?.Tools is not { Count: > 0 } || - (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + (this.MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) { break; } @@ -327,12 +333,18 @@ public override async IAsyncEnumerable GetStreamingResponseA // Prepare the history for the next iteration. FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + // Prepare the options for the next auto function invocation iteration. + UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: true); + + // Process all the functions, adding their results into the history. + var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + // Clear the auto function invocation options. + ClearOptionsForAutoFunctionInvocation(ref options!); + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages - // includes all activities, including generated function results. 
+ // include all activities, including generated function results. string toolResponseId = Guid.NewGuid().ToString("N"); foreach (var message in modeAndMessages.MessagesAdded) { @@ -454,6 +466,37 @@ private static bool CopyFunctionCalls( return any; } + private static void UpdateOptionsForAutoFunctionInvocation(ref ChatOptions options, ChatMessageContent content, bool isStreaming) + { + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false) + { + throw new KernelException($"The reserved key name '{ChatOptionsExtensions.IsStreamingKey}' is already specified in the options. Avoid using this key name."); + } + + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false) + { + throw new KernelException($"The reserved key name '{ChatOptionsExtensions.ChatMessageContentKey}' is already specified in the options. Avoid using this key name."); + } + + options.AdditionalProperties ??= []; + + options.AdditionalProperties[ChatOptionsExtensions.IsStreamingKey] = isStreaming; + options.AdditionalProperties[ChatOptionsExtensions.ChatMessageContentKey] = content; + } + + private static void ClearOptionsForAutoFunctionInvocation(ref ChatOptions options) + { + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false) + { + options.AdditionalProperties.Remove(ChatOptionsExtensions.IsStreamingKey); + } + + if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false) + { + options.AdditionalProperties.Remove(ChatOptionsExtensions.ChatMessageContentKey); + } + } + /// Updates for the response. /// true if the function calling loop should terminate; otherwise, false. private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId) @@ -517,11 +560,11 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti // Process all functions. 
If there's more than one and concurrent invocation is enabled, do so in parallel. if (functionCallContents.Count == 1) { - FunctionInvocationResult result = await ProcessFunctionCallAsync( + FunctionInvocationResult result = await this.ProcessFunctionCallAsync( messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); - IList added = CreateResponseMessages([result]); - ThrowIfNoFunctionResultsAdded(added); + IList added = this.CreateResponseMessages([result]); + this.ThrowIfNoFunctionResultsAdded(added); messages.AddRange(added); return (result.ContinueMode, added); @@ -530,12 +573,12 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti { FunctionInvocationResult[] results; - if (AllowConcurrentInvocation) + if (this.AllowConcurrentInvocation) { // Schedule the invocation of every function. results = await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) - select Task.Run(() => ProcessFunctionCallAsync( + select Task.Run(() => this.ProcessFunctionCallAsync( messages, options, functionCallContents, iteration, i, cancellationToken))).ConfigureAwait(false); } @@ -545,7 +588,7 @@ select Task.Run(() => ProcessFunctionCallAsync( results = new FunctionInvocationResult[functionCallContents.Count]; for (int i = 0; i < results.Length; i++) { - results[i] = await ProcessFunctionCallAsync( + results[i] = await this.ProcessFunctionCallAsync( messages, options, functionCallContents, iteration, i, cancellationToken).ConfigureAwait(false); } @@ -553,8 +596,8 @@ select Task.Run(() => ProcessFunctionCallAsync( ContinueMode continueMode = ContinueMode.Continue; - IList added = CreateResponseMessages(results); - ThrowIfNoFunctionResultsAdded(added); + IList added = this.CreateResponseMessages(results); + this.ThrowIfNoFunctionResultsAdded(added); messages.AddRange(added); foreach (FunctionInvocationResult fir in results) @@ -576,7 +619,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? 
messages) { if (messages is null || messages.Count == 0) { - throw new InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); + throw new InvalidOperationException($"{this.GetType().Name}.{nameof(this.CreateResponseMessages)} returned null or an empty collection of messages."); } } @@ -595,13 +638,14 @@ private async Task ProcessFunctionCallAsync( var callContent = callContents[functionCallIndex]; // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. - AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); + AIFunction? function = options.Tools!.OfType().FirstOrDefault( + t => t.Name == callContent.Name); if (function is null) { return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); } - KernelFunctionInvocationContext context = new() + var context = new AutoFunctionInvocationContext(new() { Messages = messages, Options = options, @@ -609,18 +653,17 @@ private async Task ProcessFunctionCallAsync( AIFunction = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, - FunctionCount = callContents.Count, - }; + FunctionCount = callContents.Count + }); object? result; try { - result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); + result = await this.InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); } catch (Exception e) when (!cancellationToken.IsCancellationRequested) { - return new( - RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer. + return new(this.RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer. 
FunctionInvocationStatus.Exception, callContent, result: null, @@ -681,7 +724,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi _ => "Error: Unknown error.", }; - if (IncludeDetailedErrors && result.Exception is not null) + if (this.IncludeDetailedErrors && result.Exception is not null) { message = $"{message} Exception: {result.Exception.Message}"; } @@ -744,19 +787,19 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( { Verify.NotNull(invocationContext); - using Activity? activity = _activitySource?.StartActivity(invocationContext.AIFunction.Name); + using Activity? activity = this._activitySource?.StartActivity(invocationContext.AIFunction.Name); long startingTimestamp = 0; - if (_logger.IsEnabled(LogLevel.Debug)) + if (this._logger.IsEnabled(LogLevel.Debug)) { startingTimestamp = Stopwatch.GetTimestamp(); - if (_logger.IsEnabled(LogLevel.Trace)) + if (this._logger.IsEnabled(LogLevel.Trace)) { - LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions)); + this.LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions)); } else { - LogInvoking(invocationContext.AIFunction.Name); + this.LogInvoking(invocationContext.AIFunction.Name); } } @@ -788,6 +831,8 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( var result = await invocationContext.AIFunction.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false); context.Result = new FunctionResult(context.Function, result); }).ConfigureAwait(false); + + invocationContext = autoFunctionInvocationContext; } } catch (Exception e) @@ -800,28 +845,28 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (e is OperationCanceledException) { - 
LogInvocationCanceled(invocationContext.AIFunction.Name); + this.LogInvocationCanceled(invocationContext.AIFunction.Name); } else { - LogInvocationFailed(invocationContext.AIFunction.Name, e); + this.LogInvocationFailed(invocationContext.AIFunction.Name, e); } throw; } finally { - if (_logger.IsEnabled(LogLevel.Debug)) + if (this._logger.IsEnabled(LogLevel.Debug)) { TimeSpan elapsed = GetElapsedTime(startingTimestamp); - if (result is not null && _logger.IsEnabled(LogLevel.Trace)) + if (result is not null && this._logger.IsEnabled(LogLevel.Trace)) { - LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions)); + this.LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions)); } else { - LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed); + this.LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed); } } } @@ -881,11 +926,11 @@ public sealed class FunctionInvocationResult { internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationStatus status, Microsoft.Extensions.AI.FunctionCallContent callContent, object? result, Exception? exception) { - ContinueMode = continueMode; - Status = status; - CallContent = callContent; - Result = result; - Exception = exception; + this.ContinueMode = continueMode; + this.Status = status; + this.CallContent = callContent; + this.Result = result; + this.Exception = exception; } /// Gets status about how the function invocation completed. 
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs index 98bb09be6f85..f9d4a5db151b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs @@ -8,6 +8,7 @@ using System.Text.Json; using System.Text.Json.Serialization.Metadata; using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel; @@ -149,7 +150,8 @@ internal static class PromptExecutionSettingsExtensions options.Tools = functions.Select(f => f.AsAIFunction(kernel)).Cast().ToList(); } - return options; + // Enables usage of AutoFunctionInvocationFilters + return options.AddKernel(kernel!); // Be a little lenient on the types of the values used in the extension data, // e.g. allow doubles even when requesting floats. diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index 68e06140cc9b..9dd8b094e50c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -25,17 +25,20 @@ public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContex Verify.NotNull(innerContext.Options); Verify.NotNull(innerContext.Options.AdditionalProperties); - innerContext.Options.AdditionalProperties.TryGetValue("Kernel", out var kernel); + innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); Verify.NotNull(kernel); - innerContext.Options.AdditionalProperties.TryGetValue("ChatMessageContent", out var chatMessageContent); + 
innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); Verify.NotNull(chatMessageContent); + innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming); + this.IsStreaming = isStreaming; + this._innerContext = innerContext; this.ChatHistory = new ChatMessageHistory(innerContext.Messages); this.ChatMessageContent = chatMessageContent; this.Kernel = kernel; - this.Result = new FunctionResult(this._kernelFunction!) { Culture = kernel.Culture }; + this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture }; } /// diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index a0a425aca1ec..a3af72a44d08 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -532,7 +532,7 @@ public KernelAIFunction(KernelFunction kernelFunction, Kernel? 
kernel) this.JsonSchema = BuildFunctionSchema(kernelFunction); } - + internal string KernelFunctionName => this._kernelFunction.Name; public override string Name { get; } public override JsonElement JsonSchema { get; } public override string Description => this._kernelFunction.Description; From 704c80b07babd16dae14886ed28590adb60ee7d2 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Tue, 8 Apr 2025 17:35:35 +0100 Subject: [PATCH 07/26] AutoContext losing result --- .../KernelFunctionInvokingChatClient.cs | 44 ++++---- .../AutoFunctionInvocationContext.cs | 105 +++++++++++------- 2 files changed, 88 insertions(+), 61 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 360d3d53f184..66b6e46e84ec 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -645,16 +645,20 @@ private async Task ProcessFunctionCallAsync( return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); } - var context = new AutoFunctionInvocationContext(new() + if (callContent.Arguments is not null) + { + callContent.Arguments = new KernelArguments(callContent.Arguments); + } + + var context = new AutoFunctionInvocationContext(options) { Messages = messages, - Options = options, CallContent = callContent, AIFunction = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, FunctionCount = callContents.Count - }); + }; object? 
result; try @@ -810,29 +814,23 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (invocationContext is AutoFunctionInvocationContext autoFunctionInvocationContext) { autoFunctionInvocationContext = await this.OnAutoFunctionInvocationAsync( - autoFunctionInvocationContext, - async (context) => - { - // Check if filter requested termination - if (context.Terminate) - { - return; - } - - // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any - // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, - // as the called function could in turn telling the model about itself as a possible candidate for invocation. - KernelArguments? arguments = null; - if (context.Arguments is not null) + autoFunctionInvocationContext, + async (context) => { - arguments = new(context.Arguments); - } - - var result = await invocationContext.AIFunction.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false); - context.Result = new FunctionResult(context.Function, result); + // Check if filter requested termination + if (context.Terminate) + { + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. 
+ result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false); + context.Result = new FunctionResult(context.Function, result); }).ConfigureAwait(false); - invocationContext = autoFunctionInvocationContext; + result = autoFunctionInvocationContext.Result.GetValue(); } } catch (Exception e) diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index 9dd8b094e50c..a431ed868381 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -14,30 +15,29 @@ namespace Microsoft.SemanticKernel; /// public class AutoFunctionInvocationContext : KernelFunctionInvocationContext { - private readonly KernelFunction? _kernelFunction; - private readonly KernelFunctionInvocationContext? _innerContext; + private ChatHistory? _chatHistory; + /// /// Initializes a new instance of the class from an existing . /// - public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContext) + public AutoFunctionInvocationContext(ChatOptions options) { - Verify.NotNull(innerContext); - Verify.NotNull(innerContext.Options); - Verify.NotNull(innerContext.Options.AdditionalProperties); + this.Options = options; + + // To create a AutoFunctionInvocationContext from a KernelFunctionInvocationContext, + // the ChatOptions must be provided with AdditionalProperties. 
+ Verify.NotNull(options.AdditionalProperties); - innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); Verify.NotNull(kernel); - innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); Verify.NotNull(chatMessageContent); - innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming); - this.IsStreaming = isStreaming; + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming); + Verify.NotNull(isStreaming); + this.IsStreaming = isStreaming.Value; - this._innerContext = innerContext; - this.ChatHistory = new ChatMessageHistory(innerContext.Messages); - this.ChatMessageContent = chatMessageContent; - this.Kernel = kernel; this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture }; } @@ -62,11 +62,18 @@ public AutoFunctionInvocationContext( Verify.NotNull(chatHistory); Verify.NotNull(chatMessageContent); - this.Kernel = kernel; - this._kernelFunction = function; + this.Options = new() + { + AdditionalProperties = new() + { + [ChatOptionsExtensions.ChatMessageContentKey] = chatMessageContent, + [ChatOptionsExtensions.KernelKey] = kernel + } + }; + + this.Messages = chatHistory.ToChatMessageList(); + this.AIFunction = function.AsAIFunction(); this.Result = result; - this.ChatHistory = chatHistory; - this.ChatMessageContent = chatMessageContent; } /// @@ -83,27 +90,48 @@ public AutoFunctionInvocationContext( /// /// Gets the arguments associated with the operation. /// - public KernelArguments? Arguments { get; init; } + public KernelArguments? Arguments + { + get => this.CallContent.Arguments is KernelArguments kernelArguments ? 
kernelArguments : null; + init => this.CallContent.Arguments = value; + } /// /// Request sequence index of automatic function invocation process. Starts from 0. /// - public int RequestSequenceIndex { get; init; } + public int RequestSequenceIndex + { + get => this.Iteration; + init => this.Iteration = value; + } /// /// Function sequence index. Starts from 0. /// - public int FunctionSequenceIndex { get; init; } + public int FunctionSequenceIndex + { + get => this.FunctionCallIndex; + init => this.FunctionCallIndex = value; + } /// /// The ID of the tool call. /// - public string? ToolCallId { get; init; } + public string? ToolCallId + { + get => this.CallContent.CallId; + init + { + Verify.NotNull(value); + // ToolCallId + this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value, this.CallContent.Name, this.CallContent.Arguments); + } + } /// /// The chat message content associated with automatic function invocation. /// - public ChatMessageContent ChatMessageContent { get; } + public ChatMessageContent ChatMessageContent => (this.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!; /// /// The execution settings associated with the operation. @@ -114,20 +142,27 @@ public AutoFunctionInvocationContext( /// /// Gets the associated with automatic function invocation. /// - public ChatHistory ChatHistory { get; } + public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this.Messages); /// /// Gets the with which this filter is associated. /// - public KernelFunction Function - { - get => this._innerContext?.AIFunction.AsKernelFunction() ?? this._kernelFunction!; - } + public KernelFunction Function => this.AIFunction.AsKernelFunction(); /// /// Gets the containing services, plugins, and other state for use throughout the operation. /// - public Kernel Kernel { get; } + public Kernel Kernel + { + get + { + Kernel? 
kernel = null; + this.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel); + + // To avoid exception from properties, when attempting to retrieve a kernel from a non-ready context, it will give a null. + return kernel!; + } + } /// /// Gets or sets the result of the function's invocation. @@ -135,20 +170,19 @@ public KernelFunction Function public FunctionResult Result { get; set; } /// - /// Mutable chat message as chat history. + /// Mutable IEnumerable of chat message as chat history. /// - internal class ChatMessageHistory : ChatHistory, IEnumerable + private class ChatMessageHistory : ChatHistory, IEnumerable { private readonly List _messages; - public ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory()) + internal ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory()) { this._messages = new List(messages); } public override void Add(ChatMessageContent item) { - item.GetHashCode(); base.Add(item); this._messages.Add(item.ToChatMessage()); } @@ -159,11 +193,6 @@ public override void Clear() this._messages.Clear(); } - public override bool Contains(ChatMessageContent item) - { - return base.Contains(item); - } - public override bool Remove(ChatMessageContent item) { var index = base.IndexOf(item); From e4ae4948bf1b435500c5e54f57f471230ad7c2db Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 9 Apr 2025 13:59:46 +0100 Subject: [PATCH 08/26] FilterCanOverrideArguments --- ...FunctionInvocationFilterChatClientTests.cs | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index d28c67346c3f..908acf9356a0 100644 --- 
a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -7,6 +7,7 @@ using System.Net; using System.Net.Http; using System.Threading.Tasks; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; @@ -316,7 +317,12 @@ public async Task FilterCanOverrideArgumentsAsync() })); // Assert - Assert.Equal("NewValue", result.ToString()); + var chatResponse = Assert.IsType(result.GetValue()); + Assert.NotNull(chatResponse); + + var lastFunctionResult = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(lastFunctionResult); + Assert.Equal("NewValue", lastFunctionResult.ToString()); } [Fact] @@ -728,6 +734,18 @@ public void Dispose() #region private + private static object? GetLastFunctionResultFromChatResponse(ChatResponse chatResponse) + { + Assert.NotEmpty(chatResponse.Messages); + var chatMessage = chatResponse.Messages[^1]; + + Assert.NotEmpty(chatMessage.Contents); + Assert.Contains(chatMessage.Contents, c => c is Microsoft.Extensions.AI.FunctionResultContent); + + var resultContent = (Microsoft.Extensions.AI.FunctionResultContent)chatMessage.Contents.Last(c => c is Microsoft.Extensions.AI.FunctionResultContent); + return resultContent.Result; + } + #pragma warning disable CA2000 // Dispose objects before losing scope private static List GetFunctionCallingResponses() { From 7a5fc08158a2e9cbae81faf2fb34d04a2504bcf9 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 9 Apr 2025 14:51:23 +0100 Subject: [PATCH 09/26] Fix failing UT + Aded PromptExecutionSettings to be propagated into AutoInvocationContext --- .../Connectors.OpenAI.UnitTests.csproj | 3 ++ ...FunctionInvocationFilterChatClientTests.cs | 28 +++++++++---------- 
..._multiple_function_calls_test_response.txt | 5 ++++ .../AI/ChatClient/ChatClientExtensions.cs | 26 +++++++++++++++++ .../AI/ChatClient/ChatOptionsExtensions.cs | 1 + .../KernelFunctionInvokingChatClient.cs | 2 +- .../AutoFunctionInvocationContext.cs | 13 ++++++++- .../Functions/KernelFunctionFromPrompt.cs | 2 +- 8 files changed, 63 insertions(+), 17 deletions(-) create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj index 33557430e615..04d35b9e6561 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj @@ -102,6 +102,9 @@ Always + + Always + diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index 908acf9356a0..3031ead057b4 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -67,7 +67,7 @@ public async Task FiltersAreExecutedCorrectlyAsync() // Act var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -145,7 +145,7 @@ public async Task FiltersAreExecutedCorrectlyOnStreamingAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); - var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior 
= ToolCallBehavior.AutoInvokeKernelFunctions }; + var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) @@ -350,7 +350,7 @@ public async Task FilterCanHandleExceptionAsync() var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); - var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; var chatHistory = new ChatHistory(); chatHistory.AddSystemMessage("System message"); @@ -391,7 +391,7 @@ public async Task FilterCanHandleExceptionOnStreamingAsync() var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); var chatHistory = new ChatHistory(); - var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act await foreach (var item in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel)) @@ -437,7 +437,7 @@ public async Task FiltersCanSkipFunctionExecutionAsync() // Act var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -469,9 +469,9 @@ public async Task PreFilterCanTerminateOperationAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); // Act - await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + await 
kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -501,7 +501,7 @@ public async Task PreFilterCanTerminateOperationOnStreamingAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); - var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) @@ -542,7 +542,7 @@ public async Task PostFilterCanTerminateOperationAsync() // Act var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -586,7 +586,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); - var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }; + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; List streamingContent = []; @@ -683,9 +683,9 @@ public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterCo return Task.CompletedTask; }); - var expectedExecutionSettings = new OpenAIPromptExecutionSettings + var expectedExecutionSettings = new PromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act @@ -712,9 +712,9 @@ public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingT return 
Task.CompletedTask; }); - var expectedExecutionSettings = new OpenAIPromptExecutionSettings + var expectedExecutionSettings = new PromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt new file mode 100644 index 000000000000..c113e3fa97ca --- /dev/null +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt @@ -0,0 +1,5 @@ +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]} + +data: [DONE] diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index c480772750de..f3aa9c4ad069 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -1,6 +1,7 @@ // Copyright 
(c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Threading.Tasks; @@ -30,6 +31,9 @@ internal static Task GetResponseAsync( { var chatOptions = executionSettings?.ToChatOptions(kernel); + // Passing by reference to be used by AutoFunctionInvocationFilters + chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings; + // Try to parse the text as a chat history if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) { @@ -40,6 +44,28 @@ internal static Task GetResponseAsync( return chatClient.GetResponseAsync(prompt, chatOptions, cancellationToken); } + /// Get ChatClient streaming response for the prompt, settings and kernel. + /// Target chat client service. + /// The standardized prompt input. + /// The AI execution settings (optional). + /// The containing services, plugins, and other state for use throughout the operation. + /// The to monitor for cancellation requests. The default is . + /// Streaming list of different completion streaming string updates generated by the remote model + internal static IAsyncEnumerable GetStreamingResponseAsync( + this IChatClient chatClient, + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var chatOptions = executionSettings?.ToChatOptions(kernel); + + // Passing by reference to be used by AutoFunctionInvocationFilters + chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings; + + return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken); + } + /// Creates an for the specified . /// The chat client to be represented as a chat completion service. /// An optional that can be used to resolve services to use in the instance. 
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index d6b5109f9b17..de2c6831b93d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -17,6 +17,7 @@ internal static class ChatOptionsExtensions internal const string KernelKey = "AutoInvokingKernel"; internal const string IsStreamingKey = "AutoInvokingIsStreaming"; internal const string ChatMessageContentKey = "AutoInvokingChatCompletionContent"; + internal const string PromptExecutionSettingsKey = "AutoInvokingPromptExecutionSettings"; /// Converts a to a . internal static PromptExecutionSettings? ToPromptExecutionSettings(this ChatOptions? options) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 66b6e46e84ec..3f54113ab997 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -657,7 +657,7 @@ private async Task ProcessFunctionCallAsync( AIFunction = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, - FunctionCount = callContents.Count + FunctionCount = callContents.Count, }; object? 
result; diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index a431ed868381..e642c1a55d8f 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -34,6 +34,9 @@ public AutoFunctionInvocationContext(ChatOptions options) options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); Verify.NotNull(chatMessageContent); + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); + this.ExecutionSettings = executionSettings; + options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming); Verify.NotNull(isStreaming); this.IsStreaming = isStreaming.Value; @@ -137,7 +140,15 @@ public string? ToolCallId /// The execution settings associated with the operation. /// [Experimental("SKEXP0001")] - public PromptExecutionSettings? ExecutionSettings { get; init; } + public PromptExecutionSettings? ExecutionSettings + { + get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; + init + { + this.Options.AdditionalProperties ??= []; + this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; + } + } /// /// Gets the associated with automatic function invocation. 
diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 0fc3984f8d80..9db8f8e1ce46 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -286,7 +286,7 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync Date: Wed, 9 Apr 2025 18:31:20 +0100 Subject: [PATCH 10/26] Resolving UT for FilterCanHandleException --- ...FunctionInvocationFilterChatClientTests.cs | 30 +++++++++++-------- .../ChatCompletion/ChatHistoryExtensions.cs | 5 +++- .../AI/PromptExecutionSettingsExtensions.cs | 5 ++-- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index 3031ead057b4..0e0a7c1d4246 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; @@ -185,10 +186,7 @@ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync() builder.Plugins.Add(plugin); - builder.Services.AddSingleton((serviceProvider) => - { - return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); - }); + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); @@ -202,9 +200,9 @@ public async Task 
DifferentWaysOfAddingFiltersWorkCorrectlyAsync() // Case #2 - Add filter to kernel kernel.AutoFunctionInvocationFilters.Add(filter2); - var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions + FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); // Assert @@ -348,22 +346,28 @@ public async Task FilterCanHandleExceptionAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); - var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + var chatClient = kernel.GetRequiredService(); - var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; var chatHistory = new ChatHistory(); chatHistory.AddSystemMessage("System message"); // Act - var result = await chatCompletion.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); + var messageList = chatHistory.ToChatMessageList(); + var options = executionSettings.ToChatOptions(kernel); + var resultMessages = await chatClient.GetResponseAsync(messageList, options, CancellationToken.None); - var firstFunctionResult = chatHistory[^2].Content; - var secondFunctionResult = chatHistory[^1].Content; + var firstToolMessage = resultMessages.Messages.First(m => m.Role == ChatRole.Tool); + Assert.NotNull(firstToolMessage); + var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; + var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; // Assert - Assert.Equal("Result from filter", firstFunctionResult); - Assert.Equal("Result from Function2", 
secondFunctionResult); + Assert.NotNull(firstFunctionResult); + Assert.NotNull(secondFunctionResult); + Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); + Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString()); } [Fact] diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs index a238e77417da..f027a9b7a31c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs @@ -82,6 +82,9 @@ public static async Task ReduceAsync(this ChatHistory chatHistory, return chatHistory; } - internal static List ToChatMessageList(this ChatHistory chatHistory) + /// Converts a to a list. + /// The chat history to convert. + /// A list of objects. + public static List ToChatMessageList(this ChatHistory chatHistory) => chatHistory.Select(m => m.ToChatMessage()).ToList(); } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs index f9d4a5db151b..74fe27c2e841 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs @@ -12,10 +12,11 @@ namespace Microsoft.SemanticKernel; -internal static class PromptExecutionSettingsExtensions +/// Extensions methods for . +public static class PromptExecutionSettingsExtensions { /// Converts a pair of and to a . - internal static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? kernel) + public static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? 
kernel) { if (settings is null) { From 85c8cc44ed3688e1643bd58f77c94f44fd4f6c2e Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 9 Apr 2025 18:58:34 +0100 Subject: [PATCH 11/26] Adjust for HandleExceptionONStreaming --- ...FunctionInvocationFilterChatClientTests.cs | 38 +++++++++++-------- .../ChatCompletion/ChatHistoryExtensions.cs | 2 +- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index 0e0a7c1d4246..3c42a004eb32 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -349,21 +349,18 @@ public async Task FilterCanHandleExceptionAsync() var chatClient = kernel.GetRequiredService(); var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; - - var chatHistory = new ChatHistory(); - chatHistory.AddSystemMessage("System message"); + var options = executionSettings.ToChatOptions(kernel); + List messageList = [new(ChatRole.System, "System message")]; // Act - var messageList = chatHistory.ToChatMessageList(); - var options = executionSettings.ToChatOptions(kernel); var resultMessages = await chatClient.GetResponseAsync(messageList, options, CancellationToken.None); + // Assert var firstToolMessage = resultMessages.Messages.First(m => m.Role == ChatRole.Tool); Assert.NotNull(firstToolMessage); var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; - // Assert Assert.NotNull(firstFunctionResult); 
Assert.NotNull(secondFunctionResult); Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); @@ -392,21 +389,30 @@ public async Task FilterCanHandleExceptionOnStreamingAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); - var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); + var chatClient = kernel.GetRequiredService(); - var chatHistory = new ChatHistory(); - var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; + var options = executionSettings.ToChatOptions(kernel); + List messageList = []; // Act - await foreach (var item in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel)) - { } - - var firstFunctionResult = chatHistory[^2].Content; - var secondFunctionResult = chatHistory[^1].Content; + List streamingContent = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messageList, options, CancellationToken.None)) + { + streamingContent.Add(update); + } + var chatResponse = streamingContent.ToChatResponse(); // Assert - Assert.Equal("Result from filter", firstFunctionResult); - Assert.Equal("Result from Function2", secondFunctionResult); + var firstToolMessage = chatResponse.Messages.First(m => m.Role == ChatRole.Tool); + Assert.NotNull(firstToolMessage); + var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent; + var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent; + + Assert.NotNull(firstFunctionResult); + Assert.NotNull(secondFunctionResult); + Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString()); + Assert.Equal("Result from Function2", 
secondFunctionResult.Result!.ToString()); } [Fact] diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs index f027a9b7a31c..9702bc5da22e 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs @@ -85,6 +85,6 @@ public static async Task ReduceAsync(this ChatHistory chatHistory, /// Converts a to a list. /// The chat history to convert. /// A list of objects. - public static List ToChatMessageList(this ChatHistory chatHistory) + internal static List ToChatMessageList(this ChatHistory chatHistory) => chatHistory.Select(m => m.ToChatMessage()).ToList(); } From ead8061c86aa1686193e2f00f5c486106d68a9c2 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 9 Apr 2025 20:48:03 +0100 Subject: [PATCH 12/26] Added all behavior except Skipping --- ...FunctionInvocationFilterChatClientTests.cs | 60 ++++++++---------- .../FunctionCalling/FunctionCallsProcessor.cs | 6 +- .../KernelFunctionInvokingChatClient.cs | 61 ++++++++++++------- .../AutoFunctionInvocationContext.cs | 4 -- 4 files changed, 67 insertions(+), 64 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index 3c42a004eb32..1abb85a74b52 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -248,10 +248,7 @@ public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming) builder.Plugins.Add(plugin); - builder.Services.AddSingleton((serviceProvider) => - { - return 
new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); - }); + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); builder.Services.AddSingleton(filter1); builder.Services.AddSingleton(filter2); @@ -259,24 +256,21 @@ public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming) var kernel = builder.Build(); - var arguments = new KernelArguments(new OpenAIPromptExecutionSettings - { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions - }); + var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act if (isStreaming) { this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); - await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", arguments)) + await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(settings))) { } } else { this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); - await kernel.InvokePromptAsync("Test prompt", arguments); + await kernel.InvokePromptAsync("Test prompt", new(settings)); } // Assert @@ -439,13 +433,13 @@ public async Task FiltersCanSkipFunctionExecutionAsync() filterInvocations++; }); - using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_multiple_function_calls_test_response.json")) }; + using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_chatclient_multiple_function_calls_test_response.json")) }; using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) }; this._messageHandlerStub.ResponsesToReturn = [response1, response2]; // Act - var result = await kernel.InvokePromptAsync("Test prompt", new(new 
OpenAIPromptExecutionSettings + var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); @@ -550,7 +544,7 @@ public async Task PostFilterCanTerminateOperationAsync() this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); // Act - var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings + var functionResult = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() })); @@ -562,11 +556,12 @@ public async Task PostFilterCanTerminateOperationAsync() Assert.Equal([0], functionSequenceNumbers); // Results of function invoked before termination should be returned - var lastMessageContent = result.GetValue(); - Assert.NotNull(lastMessageContent); + var chatResponse = functionResult.GetValue(); + Assert.NotNull(chatResponse); - Assert.Equal("function1-value", lastMessageContent.Content); - Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); + var result = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(result); + Assert.Equal("function1-value", result.ToString()); } [Fact] @@ -598,12 +593,12 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; - List streamingContent = []; + List streamingContent = []; // Act - await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) + await foreach (var update in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings))) { - streamingContent.Add(item); + streamingContent.Add(update); } // Assert @@ -613,13 +608,14 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() Assert.Equal([0], functionSequenceNumbers); // Results of function invoked before termination should be returned - 
Assert.Equal(3, streamingContent.Count); + Assert.Equal(4, streamingContent.Count); - var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; - Assert.NotNull(lastMessageContent); + var chatResponse = streamingContent.ToChatResponse(); + Assert.NotNull(chatResponse); - Assert.Equal("function1-value", lastMessageContent.Content); - Assert.Equal(AuthorRole.Tool, lastMessageContent.Role); + var result = GetLastFunctionResultFromChatResponse(chatResponse); + Assert.NotNull(result); + Assert.Equal("function1-value", result.ToString()); } [Theory] @@ -645,32 +641,26 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) builder.Plugins.Add(plugin); - builder.Services.AddSingleton((serviceProvider) => - { - return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient); - }); + builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient); builder.Services.AddSingleton(filter); var kernel = builder.Build(); - var arguments = new KernelArguments(new OpenAIPromptExecutionSettings - { - ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions - }); + var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }; // Act if (isStreaming) { this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses(); - await kernel.InvokePromptStreamingAsync("Test prompt", arguments).ToListAsync(); + await kernel.InvokePromptStreamingAsync("Test prompt", new(settings)).ToListAsync(); } else { this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses(); - await kernel.InvokePromptAsync("Test prompt", arguments); + await kernel.InvokePromptAsync("Test prompt", new(settings)); } // Assert diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs index 
97cb426c307d..88f3da9d6a53 100644 --- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs +++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs @@ -211,7 +211,7 @@ public FunctionCallsProcessor(ILogger? logger = null) { bool terminationRequested = false; - // Wait for all of the function invocations to complete, then add the results to the chat, but stop when we hit a + // Wait for all the function invocations to complete, then add the results to the chat, but stop when we hit a // function for which termination was requested. FunctionResultContext[] resultContexts = await Task.WhenAll(functionTasks).ConfigureAwait(false); foreach (FunctionResultContext resultContext in resultContexts) @@ -487,8 +487,8 @@ public static string ProcessFunctionResult(object functionResult) return stringResult; } - // This is an optimization to use ChatMessageContent content directly - // without unnecessary serialization of the whole message content class. + // This is an optimization to use ChatMessageContent content directly + // without unnecessary serialization of the whole message content class. if (functionResult is ChatMessageContent chatMessageContent) { return chatMessageContent.ToString(); diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 3f54113ab997..75d052842604 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -257,7 +257,7 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. 
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); // Clear the auto function invocation options. @@ -337,7 +337,7 @@ public override async IAsyncEnumerable GetStreamingResponseA UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: true); // Process all the functions, adding their results into the history. - var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, isStreaming: true, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); // Clear the auto function invocation options. @@ -547,21 +547,23 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// The options used for the response being processed. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. + /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken) + List messages, ChatOptions options, List functionCallContents, int iteration, bool isStreaming, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); + ContinueMode continueMode = ContinueMode.Continue; // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. if (functionCallContents.Count == 1) { FunctionInvocationResult result = await this.ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false); + messages, options, functionCallContents, iteration, 0, isStreaming, cancellationToken).ConfigureAwait(false); IList added = this.CreateResponseMessages([result]); this.ThrowIfNoFunctionResultsAdded(added); @@ -571,40 +573,54 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti } else { - FunctionInvocationResult[] results; + List results = []; + var terminationRequested = false; if (this.AllowConcurrentInvocation) { // Schedule the invocation of every function. - results = await Task.WhenAll( + results.AddRange(await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) select Task.Run(() => this.ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, cancellationToken))).ConfigureAwait(false); + iteration, i, isStreaming, cancellationToken))).ConfigureAwait(false)); + + terminationRequested = results.Any(r => r.ContinueMode == ContinueMode.Terminate); } else { // Invoke each function serially. 
- results = new FunctionInvocationResult[functionCallContents.Count]; - for (int i = 0; i < results.Length; i++) + for (int i = 0; i < functionCallContents.Count; i++) { - results[i] = await this.ProcessFunctionCallAsync( + var result = await this.ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, cancellationToken).ConfigureAwait(false); + iteration, i, isStreaming, cancellationToken).ConfigureAwait(false); + + results.Add(result); + + if (result.ContinueMode == ContinueMode.Terminate) + { + continueMode = ContinueMode.Terminate; + terminationRequested = true; + break; + } } } - ContinueMode continueMode = ContinueMode.Continue; - IList added = this.CreateResponseMessages(results); this.ThrowIfNoFunctionResultsAdded(added); - messages.AddRange(added); - foreach (FunctionInvocationResult fir in results) + + if (!terminationRequested) { - if (fir.ContinueMode > continueMode) + // If any function requested termination, we'll terminate. + continueMode = ContinueMode.Continue; + foreach (FunctionInvocationResult fir in results) { - continueMode = fir.ContinueMode; + if (fir.ContinueMode > continueMode) + { + continueMode = fir.ContinueMode; + } } } @@ -629,11 +645,12 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The 0-based index of the function being called out of . + /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. /// A value indicating how the caller should proceed. 
private async Task ProcessFunctionCallAsync( List messages, ChatOptions options, List callContents, - int iteration, int functionCallIndex, CancellationToken cancellationToken) + int iteration, int functionCallIndex, bool isStreaming, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; @@ -658,6 +675,7 @@ private async Task ProcessFunctionCallAsync( Iteration = iteration, FunctionCallIndex = functionCallIndex, FunctionCount = callContents.Count, + IsStreaming = isStreaming }; object? result; @@ -700,10 +718,10 @@ internal enum ContinueMode /// Information about the function call invocations and results. /// A list of all chat messages created from . internal IList CreateResponseMessages( - ReadOnlySpan results) + IReadOnlyList results) { - var contents = new List(results.Length); - for (int i = 0; i < results.Length; i++) + var contents = new List(results.Count); + for (int i = 0; i < results.Count; i++) { contents.Add(CreateFunctionResultContent(results[i])); } @@ -829,7 +847,6 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false); context.Result = new FunctionResult(context.Function, result); }).ConfigureAwait(false); - result = autoFunctionInvocationContext.Result.GetValue(); } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index e642c1a55d8f..842352fe95dd 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -37,10 +37,6 @@ public AutoFunctionInvocationContext(ChatOptions options) 
options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); this.ExecutionSettings = executionSettings; - options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming); - Verify.NotNull(isStreaming); - this.IsStreaming = isStreaming.Value; - this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture }; } From df4b108a6f44ee5afecdfba281aa7bde102a119f Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 10 Apr 2025 19:00:13 +0100 Subject: [PATCH 13/26] Fix warnings --- .../AI/ChatClient/ChatClientExtensions.cs | 19 +++++++++++-------- .../AI/ChatClient/ChatOptionsExtensions.cs | 11 +++++++---- .../AutoFunctionInvocationContext.cs | 2 +- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index f3aa9c4ad069..6d6b80fe28bb 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -29,10 +29,7 @@ internal static Task GetResponseAsync( Kernel? kernel = null, CancellationToken cancellationToken = default) { - var chatOptions = executionSettings?.ToChatOptions(kernel); - - // Passing by reference to be used by AutoFunctionInvocationFilters - chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings; + var chatOptions = GetChatOptionsFromSettings(executionSettings, kernel); // Try to parse the text as a chat history if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt)) @@ -58,10 +55,7 @@ internal static IAsyncEnumerable GetStreamingResponseAsync( Kernel? 
kernel = null, CancellationToken cancellationToken = default) { - var chatOptions = executionSettings?.ToChatOptions(kernel); - - // Passing by reference to be used by AutoFunctionInvocationFilters - chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings; + var chatOptions = GetChatOptionsFromSettings(executionSettings, kernel); return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken); } @@ -109,4 +103,13 @@ public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient cl ? kernelFunctionInvocationClient : new KernelFunctionInvokingChatClient(client, logger); } + + private static ChatOptions GetChatOptionsFromSettings(PromptExecutionSettings? executionSettings, Kernel? kernel) + { + ChatOptions chatOptions = executionSettings?.ToChatOptions(kernel) ?? new ChatOptions().AddKernel(kernel); + + // Passing by reference to be used by AutoFunctionInvocationFilters + chatOptions.AdditionalProperties![ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings; + return chatOptions; + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index de2c6831b93d..93d076090dcc 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -130,13 +130,16 @@ internal static class ChatOptionsExtensions /// /// Chat options. /// Kernel to be used for auto function invocation. - internal static ChatOptions? AddKernel(this ChatOptions options, Kernel kernel) + internal static ChatOptions AddKernel(this ChatOptions options, Kernel? 
kernel) { Verify.NotNull(options); - Verify.NotNull(kernel); - options.AdditionalProperties ??= []; - options.AdditionalProperties?.TryAdd(KernelKey, kernel); + // Only add the kernel if it is provided + if (kernel is not null) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties?.TryAdd(KernelKey, kernel); + } return options; } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index 842352fe95dd..d4fbe564fcd4 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; @@ -141,6 +140,7 @@ public PromptExecutionSettings? 
ExecutionSettings get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; init { + this.Options ??= new(); this.Options.AdditionalProperties ??= []; this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; } From ffc01d925bd1c903afdc0f2fd47ea962ffdfc61d Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:19:51 +0100 Subject: [PATCH 14/26] AutoINvocation Skipping and ChatReducer passing --- ...FunctionInvocationFilterChatClientTests.cs | 2 +- .../Core/AutoFunctionInvocationFilterTests.cs | 2 +- .../AI/ChatCompletion/ChatHistory.cs | 61 +++++++++++++++-- .../ChatCompletion/ChatHistoryExtensions.cs | 58 +++++++++++++++++ .../AutoFunctionInvocationContext.cs | 65 ++++++++++++++++++- .../FunctionCallsProcessorTests.cs | 12 ++-- 6 files changed, 187 insertions(+), 13 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index 1abb85a74b52..cb47c48af8d7 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -425,7 +425,7 @@ public async Task FiltersCanSkipFunctionExecutionAsync() var kernel = this.GetKernelWithFilter(plugin, async (context, next) => { // Filter delegate is invoked only for second function, the first one should be skipped. 
- if (context.Function.Name == "Function2") + if (context.Function.Name == "MyPlugin_Function2") { await next(context); } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs index 1f23d84c4dfe..b308206b12d5 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs @@ -596,7 +596,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync() Assert.Equal([0], requestSequenceNumbers); Assert.Equal([0], functionSequenceNumbers); - // Results of function invoked before termination should be returned + // Results of function invoked before termination should be returned Assert.Equal(3, streamingContent.Count); var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent; diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs index cfdad840349c..0054dc98a400 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs @@ -18,6 +18,14 @@ public class ChatHistory : IList, IReadOnlyListThe messages. private readonly List _messages; + private Action? _overrideAdd; + private Func? _overrideRemove; + private Action? _overrideClear; + private Action? _overrideInsert; + private Action? _overrideRemoveAt; + private Action? _overrideRemoveRange; + private Action>? _overrideAddRange; + /// Initializes an empty history. 
/// /// Creates a new instance of the class @@ -27,6 +35,36 @@ public ChatHistory() this._messages = []; } + // Due to changes using the AutoFunctionInvocation as a dependency of KernelInvocation, that needs to reflect + internal void SetOverrides( + Action overrideAdd, + Func overrideRemove, + Action onClear, + Action overrideInsert, + Action overrideRemoveAt, + Action overrideRemoveRange, + Action> overrideAddRange) + { + this._overrideAdd = overrideAdd; + this._overrideRemove = overrideRemove; + this._overrideClear = onClear; + this._overrideInsert = overrideInsert; + this._overrideRemoveAt = overrideRemoveAt; + this._overrideRemoveRange = overrideRemoveRange; + this._overrideAddRange = overrideAddRange; + } + + internal void ClearOverrides() + { + this._overrideAdd = null; + this._overrideRemove = null; + this._overrideClear = null; + this._overrideInsert = null; + this._overrideRemoveAt = null; + this._overrideRemoveRange = null; + this._overrideAddRange = null; + } + /// /// Creates a new instance of the with a first message in the provided . /// If not role is provided then the first message will default to role. @@ -121,6 +159,7 @@ public virtual void Add(ChatMessageContent item) { Verify.NotNull(item); this._messages.Add(item); + this._overrideAdd?.Invoke(item); } /// Adds the messages to the history. @@ -130,6 +169,7 @@ public virtual void AddRange(IEnumerable items) { Verify.NotNull(items); this._messages.AddRange(items); + this._overrideAddRange?.Invoke(items); } /// Inserts a message into the history at the specified index. @@ -140,6 +180,7 @@ public virtual void Insert(int index, ChatMessageContent item) { Verify.NotNull(item); this._messages.Insert(index, item); + this._overrideInsert?.Invoke(index, item); } /// @@ -150,10 +191,15 @@ public virtual void Insert(int index, ChatMessageContent item) /// is null. /// The number of messages in the history is greater than the available space from to the end of . /// is less than 0. 
- public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex); + public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) + => this._messages.CopyTo(array, arrayIndex); /// Removes all messages from the history. - public virtual void Clear() => this._messages.Clear(); + public virtual void Clear() + { + this._messages.Clear(); + this._overrideClear?.Invoke(); + } /// Gets or sets the message at the specified index in the history. /// The index of the message to get or set. @@ -193,7 +239,11 @@ public virtual int IndexOf(ChatMessageContent item) /// Removes the message at the specified index from the history. /// The index of the message to remove. /// The was not valid for this history. - public virtual void RemoveAt(int index) => this._messages.RemoveAt(index); + public virtual void RemoveAt(int index) + { + this._messages.RemoveAt(index); + this._overrideRemoveAt?.Invoke(index); + } /// Removes the first occurrence of the specified message from the history. /// The message to remove from the history. 
@@ -202,7 +252,9 @@ public virtual int IndexOf(ChatMessageContent item) public virtual bool Remove(ChatMessageContent item) { Verify.NotNull(item); - return this._messages.Remove(item); + var result = this._messages.Remove(item); + this._overrideRemove?.Invoke(item); + return result; } /// @@ -216,6 +268,7 @@ public virtual bool Remove(ChatMessageContent item) public virtual void RemoveRange(int index, int count) { this._messages.RemoveRange(index, count); + this._overrideRemoveRange?.Invoke(index, count); } /// diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs index 9702bc5da22e..8a65e502464f 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs @@ -87,4 +87,62 @@ public static async Task ReduceAsync(this ChatHistory chatHistory, /// A list of objects. 
internal static List ToChatMessageList(this ChatHistory chatHistory) => chatHistory.Select(m => m.ToChatMessage()).ToList(); + + internal static void SetChatMessageHandlers(this ChatHistory chatHistory, IList messages) + { + chatHistory.SetOverrides(Add, Remove, Clear, Insert, RemoveAt, RemoveRange, AddRange); + + void Add(ChatMessageContent item) + { + messages.Add(item.ToChatMessage()); + } + + void Clear() + { + messages.Clear(); + } + + bool Remove(ChatMessageContent item) + { + var index = chatHistory.IndexOf(item); + + if (index < 0) + { + return false; + } + + messages.RemoveAt(index); + + return true; + } + + void Insert(int index, ChatMessageContent item) + { + messages.Insert(index, item.ToChatMessage()); + } + + void RemoveAt(int index) + { + messages.RemoveAt(index); + } + + void RemoveRange(int index, int count) + { + if (messages is List messageList) + { + messageList.RemoveRange(index, count); + return; + } + + foreach (var chatMessage in messages.Skip(index).Take(count)) + { + messages.Remove(chatMessage); + } + } + + void AddRange(IEnumerable items) + { + messages.AddRange(items.Select(i => i.ToChatMessage())); + } + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index d4fbe564fcd4..b15ac9b00fe7 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. 
+using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Linq; using System.Threading; using Microsoft.Extensions.AI; using Microsoft.SemanticKernel.ChatCompletion; @@ -15,6 +17,7 @@ namespace Microsoft.SemanticKernel; public class AutoFunctionInvocationContext : KernelFunctionInvocationContext { private ChatHistory? _chatHistory; + private KernelFunction? _kernelFunction; /// /// Initializes a new instance of the class from an existing . @@ -69,7 +72,10 @@ public AutoFunctionInvocationContext( } }; + this._kernelFunction = function; + this._chatHistory = chatHistory; this.Messages = chatHistory.ToChatMessageList(); + chatHistory.SetChatMessageHandlers(this.Messages); this.AIFunction = function.AsAIFunction(); this.Result = result; } @@ -154,7 +160,21 @@ public PromptExecutionSettings? ExecutionSettings /// /// Gets the with which this filter is associated. /// - public KernelFunction Function => this.AIFunction.AsKernelFunction(); + public KernelFunction Function + { + get + { + if (this._kernelFunction is null + // If the schemas are different, + // AIFunction reference potentially was modified and the kernel function should be regenerated. + || !IsSameSchema(this._kernelFunction, this.AIFunction)) + { + this._kernelFunction = this.AIFunction.AsKernelFunction(); + } + + return this._kernelFunction; + } + } /// /// Gets the containing services, plugins, and other state for use throughout the operation. @@ -176,6 +196,18 @@ public Kernel Kernel /// public FunctionResult Result { get; set; } + private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction) + { + // Compares the schemas, should be similar. + return string.Equals( + kernelFunction.AsAIFunction().JsonSchema.ToString(), + aiFunction.JsonSchema.ToString(), + StringComparison.OrdinalIgnoreCase); + + // TODO: Later can be improved by comparing the underlying methods. 
+ // return kernelFunction.UnderlyingMethod == aiFunction.UnderlyingMethod; + } + /// /// Mutable IEnumerable of chat message as chat history. /// @@ -237,6 +269,30 @@ public override ChatMessageContent this[int index] } } + public override void RemoveRange(int index, int count) + { + this._messages.RemoveRange(index, count); + base.RemoveRange(index, count); + } + + public override void CopyTo(ChatMessageContent[] array, int arrayIndex) + { + for (int i = 0; i < this._messages.Count; i++) + { + array[arrayIndex + i] = this._messages[i].ToChatMessageContent(); + } + } + + public override bool Contains(ChatMessageContent item) => base.Contains(item); + + public override int IndexOf(ChatMessageContent item) => base.IndexOf(item); + + public override void AddRange(IEnumerable items) + { + base.AddRange(items); + this._messages.AddRange(items.Select(i => i.ToChatMessage())); + } + public override int Count => this._messages.Count; // Explicit implementation of IEnumerable.GetEnumerator() @@ -252,4 +308,11 @@ IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); } + + ~AutoFunctionInvocationContext() + { + // The moment this class is destroyed, we need to clear the overrides and + // overrides to update to message + this._chatHistory?.ClearOverrides(); + } } diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs index 9015b2908b2f..ca3f269abc64 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs @@ -184,7 +184,7 @@ public async Task ItShouldAddFunctionInvocationExceptionToChatHistoryAsync() var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new 
FunctionCallContent("Function1", "MyPlugin")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -213,7 +213,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionCallNotAdvertisedAsync( var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -242,7 +242,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionIsNotRegisteredOnKernel var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // The call for function that is not registered on the kernel + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // The call for function that is not registered on the kernel // Act await this._sut.ProcessFunctionCallsAsync( @@ -743,7 +743,7 @@ public async Task ItShouldHandleChatMessageContentAsFunctionResultAsync() var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -779,7 +779,7 @@ public async Task ItShouldSerializeFunctionResultOfUnknownTypeAsync() var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // Act await this._sut.ProcessFunctionCallsAsync( @@ 
-870,7 +870,7 @@ public async Task ItShouldPassPromptExecutionSettingsToAutoFunctionInvocationFil }); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", arguments: new KernelArguments() { ["parameter"] = "function1-result" })); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test", arguments: new KernelArguments() { ["parameter"] = "function1-result" })); // Act await this._sut.ProcessFunctionCallsAsync( From f7ee7a2aa22922cc25498572a29a45ac4169e11a Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 11 Apr 2025 11:30:06 +0100 Subject: [PATCH 15/26] Fix CallId optionality --- .../AI/ChatClient/ChatOptionsExtensions.cs | 1 - .../ChatClient/KernelFunctionInvokingChatClient.cs | 2 +- .../AutoFunctionInvocationContext.cs | 5 ++--- .../AIConnectors/FunctionCallsProcessorTests.cs | 12 ++++++------ 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index 93d076090dcc..5db8240b1707 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Text.Json; using Microsoft.Extensions.AI; diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 75d052842604..1d48111b47af 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ 
b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -846,7 +846,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( // as the called function could in turn telling the model about itself as a possible candidate for invocation. result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false); context.Result = new FunctionResult(context.Function, result); - }).ConfigureAwait(false); + }).ConfigureAwait(false); result = autoFunctionInvocationContext.Result.GetValue(); } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index b15ac9b00fe7..e3318555b129 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -126,9 +126,7 @@ public string? ToolCallId get => this.CallContent.CallId; init { - Verify.NotNull(value); - // ToolCallId - this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value, this.CallContent.Name, this.CallContent.Arguments); + this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value ?? string.Empty, this.CallContent.Name, this.CallContent.Arguments); } } @@ -309,6 +307,7 @@ IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); } + /// Destructor to clear the overrides and update the message. 
~AutoFunctionInvocationContext() { // The moment this class is destroyed, we need to clear the overrides and diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs index ca3f269abc64..9015b2908b2f 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs @@ -184,7 +184,7 @@ public async Task ItShouldAddFunctionInvocationExceptionToChatHistoryAsync() var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -213,7 +213,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionCallNotAdvertisedAsync( var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -242,7 +242,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionIsNotRegisteredOnKernel var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // The call for function that is not registered on the kernel + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // The call for function that is not registered on the kernel // Act await this._sut.ProcessFunctionCallsAsync( @@ -743,7 +743,7 @@ public async Task 
ItShouldHandleChatMessageContentAsFunctionResultAsync() var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -779,7 +779,7 @@ public async Task ItShouldSerializeFunctionResultOfUnknownTypeAsync() var chatHistory = new ChatHistory(); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // Act await this._sut.ProcessFunctionCallsAsync( @@ -870,7 +870,7 @@ public async Task ItShouldPassPromptExecutionSettingsToAutoFunctionInvocationFil }); var chatMessageContent = new ChatMessageContent(); - chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test", arguments: new KernelArguments() { ["parameter"] = "function1-result" })); + chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", arguments: new KernelArguments() { ["parameter"] = "function1-result" })); // Act await this._sut.ProcessFunctionCallsAsync( From 7344afd3489b9673826fdb6ea481a16c674733dd Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 11 Apr 2025 12:26:11 +0100 Subject: [PATCH 16/26] Fix warnings --- .../Core/AutoFunctionInvocationFilterChatClientTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs index cb47c48af8d7..dd8d94c99824 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs 
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs @@ -11,7 +11,6 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; -using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.OpenAI; using Xunit; From 196fa98c426120e77593e3be38f077db316e28e8 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 11 Apr 2025 13:01:50 +0100 Subject: [PATCH 17/26] Fix 9.4.0 conflicts and errors --- .../Agents/ChatCompletion_ServiceSelection.cs | 10 ++-- .../Kernel/CustomAIServiceSelector.cs | 8 +-- .../Step06_DependencyInjection.cs | 22 ++++---- ...IServiceCollectionExtensions.ChatClient.cs | 24 ++++----- .../OpenAI/OpenAIAudioToTextTests.cs | 2 +- .../OpenAI/OpenAIChatCompletionTests.cs | 53 +++++++++---------- .../OpenAI/OpenAITextToAudioTests.cs | 2 +- .../samples/InternalUtilities/BaseTest.cs | 22 ++++---- .../AI/ChatClient/AIFunctionFactory.cs | 12 ++--- .../AI/ChatClient/ChatClientAIService.cs | 2 +- .../AI/ChatClient/ChatClientExtensions.cs | 2 +- .../KernelFunctionInvokingChatClient.cs | 2 +- .../Functions/KernelFunctionFromPrompt.cs | 4 +- .../AIFunctionKernelFunctionTests.cs | 4 +- .../CustomAIChatClientSelectorTests.cs | 2 +- .../OrderedAIServiceSelectorTests.cs | 2 +- 16 files changed, 87 insertions(+), 86 deletions(-) diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs index a0aa892a7802..be69fe412d5e 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs @@ -109,11 +109,11 @@ private Kernel CreateKernelWithTwoServices(bool useChatClient) { builder.Services.AddKeyedChatClient( ServiceKeyBad, - new OpenAI.OpenAIClient("bad-key").AsChatClient(TestConfiguration.OpenAI.ChatModelId)); + new 
OpenAI.OpenAIClient("bad-key").GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient()); builder.Services.AddKeyedChatClient( ServiceKeyGood, - new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).AsChatClient(TestConfiguration.OpenAI.ChatModelId)); + new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient()); } else { @@ -122,14 +122,16 @@ private Kernel CreateKernelWithTwoServices(bool useChatClient) new Azure.AI.OpenAI.AzureOpenAIClient( new Uri(TestConfiguration.AzureOpenAI.Endpoint), new Azure.AzureKeyCredential("bad-key")) - .AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)); + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient()); builder.Services.AddKeyedChatClient( ServiceKeyGood, new Azure.AI.OpenAI.AzureOpenAIClient( new Uri(TestConfiguration.AzureOpenAI.Endpoint), new Azure.AzureKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) - .AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)); + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient()); } } else diff --git a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs index d4631323c24d..02ddbdb3ec35 100644 --- a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs +++ b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs @@ -10,7 +10,7 @@ namespace KernelExamples; /// -/// This sample shows how to use a custom AI service selector to select a specific model by matching it's id. +/// This sample shows how to use a custom AI service selector to select a specific model by matching the model id. 
/// public class CustomAIServiceSelector(ITestOutputHelper output) : BaseTest(output) { @@ -39,7 +39,8 @@ public async Task UsingCustomSelectToSelectServiceByMatchingModelId() builder.Services .AddSingleton(customSelector) .AddKeyedChatClient("OpenAIChatClient", new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) - .AsChatClient("gpt-4o")); // Add a IChatClient to the kernel + .GetChatClient("gpt-4o") + .AsIChatClient()); // Add a IChatClient to the kernel Kernel kernel = builder.Build(); @@ -60,7 +61,6 @@ private sealed class GptAIServiceSelector(string modelNameStartsWith, ITestOutpu private readonly ITestOutputHelper _output = output; private readonly string _modelNameStartsWith = modelNameStartsWith; - /// private bool TrySelect( Kernel kernel, KernelFunction function, KernelArguments arguments, [NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class @@ -78,7 +78,7 @@ private bool TrySelect( else if (serviceToCheck is IChatClient chatClient) { var metadata = chatClient.GetService(); - serviceModelId = metadata?.ModelId; + serviceModelId = metadata?.DefaultModelId; endpoint = metadata?.ProviderUri?.ToString(); } diff --git a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs index 8935e4d66d48..39106b957841 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs @@ -43,25 +43,25 @@ public async Task UseDependencyInjectionToCreateAgentAsync(bool useChatClient) IChatClient chatClient; if (this.UseOpenAIConfig) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey), - TestConfiguration.OpenAI.ChatModelId); + chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); } else if 
(!string.IsNullOrEmpty(this.ApiKey)) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)), - modelId: TestConfiguration.AzureOpenAI.ChatModelId); + credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); } else { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new AzureCliCredential()), - modelId: TestConfiguration.AzureOpenAI.ChatModelId); + credential: new AzureCliCredential()) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); } var functionCallingChatClient = chatClient!.AsKernelFunctionInvokingChatClient(); diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs index 2ae915869698..7fa002d806ea 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs @@ -47,11 +47,11 @@ public static IServiceCollection AddOpenAIChatClient( IChatClient Factory(IServiceProvider serviceProvider, object? _) { - ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + ILogger? logger = serviceProvider.GetService()?.CreateLogger(); - return new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? 
SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))), - modelId: modelId) + return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) + .GetChatClient(modelId) + .AsIChatClient() .AsKernelFunctionInvokingChatClient(logger); } @@ -77,11 +77,11 @@ public static IServiceCollection AddOpenAIChatClient(this IServiceCollection ser IChatClient Factory(IServiceProvider serviceProvider, object? _) { - ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + ILogger? logger = serviceProvider.GetService()?.CreateLogger(); - return new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient ?? serviceProvider.GetRequiredService(), - modelId) + return (openAIClient ?? serviceProvider.GetRequiredService()) + .GetChatClient(modelId) + .AsIChatClient() .AsKernelFunctionInvokingChatClient(logger); } @@ -114,11 +114,11 @@ public static IServiceCollection AddOpenAIChatClient( IChatClient Factory(IServiceProvider serviceProvider, object? _) { - ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + ILogger? logger = serviceProvider.GetService()?.CreateLogger(); - return new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))), - modelId: modelId) + return new OpenAIClient(new ApiKeyCredential(apiKey ?? 
SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) + .GetChatClient(modelId) + .AsIChatClient() .AsKernelFunctionInvokingChatClient(logger); } diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs index 90375307c533..9e1127fa8b55 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs @@ -22,7 +22,7 @@ public sealed class OpenAIAudioToTextTests() .AddUserSecrets() .Build(); - [RetryFact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [RetryFact] //(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] public async Task OpenAIAudioToTextTestAsync() { // Arrange diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs index fe8ff155d9c5..ddfe6b997a25 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs @@ -12,7 +12,6 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Http.Resilience; -using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; using OpenAI; @@ -48,16 +47,16 @@ public async Task ItCanUseOpenAiChatForTextGenerationAsync() [Fact] public async Task ItCanUseOpenAiChatClientAndContentsAsync() { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var 
openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); // Arrange - var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey); + var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey); var builder = Kernel.CreateBuilder(); - builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId)); + builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient()); var kernel = builder.Build(); var func = kernel.CreateFunctionFromPrompt( @@ -104,16 +103,16 @@ public async Task OpenAIStreamingTestAsync() [Fact] public async Task ItCanUseOpenAiStreamingChatClientAndContentsAsync() { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); // Arrange - var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey); + var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey); var builder = Kernel.CreateBuilder(); - builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId)); + builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient()); var kernel = builder.Build(); var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin"); @@ -179,7 +178,7 @@ public async Task OpenAIHttpRetryPolicyTestAsync() // Assert Assert.All(statusCodes, s => 
Assert.Equal(HttpStatusCode.Unauthorized, s)); - Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode); + Assert.Equal(HttpStatusCode.Unauthorized, exception.StatusCode); } [Fact] @@ -258,11 +257,11 @@ public async Task SemanticKernelVersionHeaderIsSentAsync() var kernel = this.CreateAndInitializeKernel(httpClient); // Act - var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); + await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?"); // Assert Assert.NotNull(httpHeaderHandler.RequestHeaders); - Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); + Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var _)); } //[Theory(Skip = "This test is for manual verification.")] @@ -301,18 +300,18 @@ public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? private Kernel CreateAndInitializeKernel(HttpClient? 
httpClient = null) { - var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get(); - Assert.NotNull(OpenAIConfiguration); - Assert.NotNull(OpenAIConfiguration.ChatModelId); - Assert.NotNull(OpenAIConfiguration.ApiKey); - Assert.NotNull(OpenAIConfiguration.ServiceId); + var openAIConfiguration = this._configuration.GetSection("OpenAI").Get(); + Assert.NotNull(openAIConfiguration); + Assert.NotNull(openAIConfiguration.ChatModelId); + Assert.NotNull(openAIConfiguration.ApiKey); + Assert.NotNull(openAIConfiguration.ServiceId); - var kernelBuilder = base.CreateKernelBuilder(); + var kernelBuilder = this.CreateKernelBuilder(); kernelBuilder.AddOpenAIChatCompletion( - modelId: OpenAIConfiguration.ChatModelId, - apiKey: OpenAIConfiguration.ApiKey, - serviceId: OpenAIConfiguration.ServiceId, + modelId: openAIConfiguration.ChatModelId, + apiKey: openAIConfiguration.ApiKey, + serviceId: openAIConfiguration.ServiceId, httpClient: httpClient); return kernelBuilder.Build(); diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs index c2818abe2502..420295fe4349 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs @@ -18,7 +18,7 @@ public sealed class OpenAITextToAudioTests .AddUserSecrets() .Build(); - [Fact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")] + [Fact] //(Skip = "OpenAI will often throttle requests. 
This test is for manual verification.")] public async Task OpenAITextToAudioTestAsync() { // Arrange diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs index 2fefb6ee9d16..3a8d561f4eaf 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs @@ -99,25 +99,25 @@ protected IChatClient AddChatClientToKernel(IKernelBuilder builder) IChatClient chatClient; if (this.UseOpenAIConfig) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey), - TestConfiguration.OpenAI.ChatModelId); + chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey) + .GetChatClient(TestConfiguration.OpenAI.ChatModelId) + .AsIChatClient(); } else if (!string.IsNullOrEmpty(this.ApiKey)) { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)), - modelId: TestConfiguration.AzureOpenAI.ChatDeploymentName); + credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient(); } else { - chatClient = new Microsoft.Extensions.AI.OpenAIChatClient( - openAIClient: new AzureOpenAIClient( + chatClient = new AzureOpenAIClient( endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint), - credential: new AzureCliCredential()), - modelId: TestConfiguration.AzureOpenAI.ChatDeploymentName); + credential: new AzureCliCredential()) + .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName) + .AsIChatClient(); } var functionCallingChatClient = chatClient!.AsKernelFunctionInvokingChatClient(); diff --git 
a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs index a0d6b1865a8f..e7afdd77e153 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs @@ -194,8 +194,8 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - protected override Task InvokeCoreAsync( - IEnumerable>? arguments, + protected override ValueTask InvokeCoreAsync( + AIFunctionArguments? arguments, CancellationToken cancellationToken) { var paramMarshallers = FunctionDescriptor.ParameterMarshallers; @@ -283,7 +283,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions public JsonSerializerOptions JsonSerializerOptions { get; } public JsonElement JsonSchema { get; } public Func, CancellationToken, object?>[] ParameterMarshallers { get; } - public Func> ReturnParameterMarshaller { get; } + public Func> ReturnParameterMarshaller { get; } public ReflectionAIFunction? CachedDefaultInstance { get; set; } private static string GetFunctionName(MethodInfo method) @@ -395,7 +395,7 @@ static bool IsAsyncMethod(MethodInfo method) /// /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. 
/// - private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) + private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) { Type returnType = method.ReturnType; JsonTypeInfo returnTypeInfo; @@ -403,7 +403,7 @@ static bool IsAsyncMethod(MethodInfo method) // Void if (returnType == typeof(void)) { - return static (_, _) => Task.FromResult(null); + return static (_, _) => new ValueTask((object?)null); } // Task @@ -461,7 +461,7 @@ static bool IsAsyncMethod(MethodInfo method) returnTypeInfo = serializerOptions.GetTypeInfo(returnType); return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); - static async Task SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) + static async ValueTask SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) { if (returnTypeInfo.Kind is JsonTypeInfoKind.None) { diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs index 8a5abc42a6e0..b840b33e690b 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs @@ -33,7 +33,7 @@ internal ChatClientAIService(IChatClient chatClient) var metadata = this._chatClient.GetService(); Verify.NotNull(metadata); - this._internalAttributes[nameof(metadata.ModelId)] = metadata.ModelId; + this._internalAttributes["ModelId"] = metadata.DefaultModelId; this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName; this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri; } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs 
b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index 6d6b80fe28bb..0ffdd5fec99d 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -85,7 +85,7 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl { Verify.NotNull(client); - return client.GetService()?.ModelId; + return client.GetService()?.DefaultModelId; } /// diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 1d48111b47af..10552971aba8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -844,7 +844,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, // as the called function could in turn telling the model about itself as a possible candidate for invocation. 
- result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false); + result = await invocationContext.AIFunction.InvokeAsync(new(autoFunctionInvocationContext.Arguments), cancellationToken).ConfigureAwait(false); context.Result = new FunctionResult(context.Function, result); }).ConfigureAwait(false); result = autoFunctionInvocationContext.Result.GetValue(); diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs index 6446dd1b97d5..8037f513ee65 100644 --- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs +++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs @@ -818,10 +818,10 @@ private async Task GetChatClientResultAsync( }; } - var modelId = chatClient.GetService()?.ModelId; + var modelId = chatClient.GetService()?.DefaultModelId; // Usage details are global and duplicated for each chat message content, use first one to get usage information - this.CaptureUsageDetails(chatClient.GetService()?.ModelId, chatResponse.Usage, this._logger); + this.CaptureUsageDetails(chatClient.GetService()?.DefaultModelId, chatResponse.Usage, this._logger); return new FunctionResult(this, chatResponse) { diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs index 065784118c3d..20a2460e1751 100644 --- a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs @@ -14,8 +14,8 @@ public class AIFunctionKernelFunctionTests public void ShouldAssignIsRequiredParameterMetadataPropertyCorrectly() { // Arrange and Act - AIFunction aiFunction = AIFunctionFactory.Create((string p1, int? 
p2 = null) => p1, - new AIFunctionFactoryOptions { JsonSchemaCreateOptions = new AIJsonSchemaCreateOptions { RequireAllProperties = false } }); + AIFunction aiFunction = Microsoft.Extensions.AI.AIFunctionFactory.Create((string p1, int? p2 = null) => p1, + new Microsoft.Extensions.AI.AIFunctionFactoryOptions { JsonSchemaCreateOptions = new AIJsonSchemaCreateOptions { RequireAllProperties = false } }); AIFunctionKernelFunction sut = new(aiFunction); diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs index 322c5f3b935f..7400875cdc9b 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs @@ -70,7 +70,7 @@ private sealed class ChatClientTest : IChatClient public ChatClientTest() { - this._metadata = new ChatClientMetadata(modelId: "Value1"); + this._metadata = new ChatClientMetadata(defaultModelId: "Value1"); } public void Dispose() diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs index b0a57fa8cdf6..52619c22af6e 100644 --- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs +++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs @@ -527,7 +527,7 @@ private sealed class ChatClient : IChatClient public ChatClient(string modelId) { - this.Metadata = new ChatClientMetadata(modelId: modelId); + this.Metadata = new ChatClientMetadata(defaultModelId: modelId); } public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? 
kernel = null, CancellationToken cancellationToken = default) From 83c2e18389303c1715db6dbaba49ac409ecdcf47 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 11 Apr 2025 17:21:47 +0100 Subject: [PATCH 18/26] Internalizing components --- .../KernelFunctionInvocationContext.cs | 6 +- .../KernelFunctionInvokingChatClient.cs | 95 ++++++++-------- .../AutoFunctionInvocationContext.cs | 106 ++++++++++++------ 3 files changed, 123 insertions(+), 84 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs index a36f03630fd5..c8f293b2bb0c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs @@ -16,10 +16,10 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// Provides context for an in-flight function invocation. [ExcludeFromCodeCoverage] -public class KernelFunctionInvocationContext +internal class KernelFunctionInvocationContext { /// - /// A nop function used to allow to be non-nullable. Default instances of + /// A nop function used to allow to be non-nullable. Default instances of /// start with this as the target function. /// private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext)); @@ -64,7 +64,7 @@ public IList Messages public ChatOptions? Options { get; set; } /// Gets or sets the AI function to be invoked. 
- public AIFunction AIFunction + public AIFunction Function { get => _function; set diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 10552971aba8..3a5eb88b5cf3 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -47,8 +47,8 @@ namespace Microsoft.SemanticKernel.ChatCompletion; [ExcludeFromCodeCoverage] internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient { - /// The for the current function invocation. - private static readonly AsyncLocal s_currentContext = new(); + /// The for the current function invocation. + private static readonly AsyncLocal s_currentContext = new(); /// The logger to use for logging information about function invocation. private readonly ILogger _logger; @@ -78,7 +78,7 @@ public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger /// /// This value flows across async calls. /// - internal static KernelFunctionInvocationContext? CurrentContext + internal static AutoFunctionInvocationContext? CurrentContext { get => s_currentContext.Value; set => s_currentContext.Value = value; @@ -120,13 +120,13 @@ internal static KernelFunctionInvocationContext? CurrentContext /// /// /// Setting the value to can help the underlying bypass problems on - /// its own, for example by retrying the function call with different arguments. However it might + /// its own, for example by retrying the function call with different arguments. However, it might /// result in disclosing the raw exception information to external users, which can be a security /// concern depending on the application scenario. 
/// /// /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether detailed errors are provided during an in-flight request. + /// whether detailed errors are provided during an in-flight request. /// /// public bool IncludeDetailedErrors { get; set; } @@ -227,7 +227,7 @@ public override async Task GetResponseAsync( return response; } - // Track aggregatable details from the response, including all of the response messages and usage details. + // Track aggregate details from the response, including all the response messages and usage details. (responseMessages ??= []).AddRange(response.Messages); if (response.Usage is not null) { @@ -257,13 +257,13 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. - var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); // Clear the auto function invocation options. 
- ClearOptionsForAutoFunctionInvocation(ref options!); + ClearOptionsForAutoFunctionInvocation(ref options); - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId)) + if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) { // Terminate break; @@ -271,7 +271,7 @@ public override async Task GetResponseAsync( } Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); - response.Messages = responseMessages!; + response.Messages = responseMessages; response.Usage = totalUsage; return response; @@ -334,14 +334,14 @@ public override async IAsyncEnumerable GetStreamingResponseA FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); // Prepare the options for the next auto function invocation iteration. - UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: true); + UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true); // Process all the functions, adding their results into the history. var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, isStreaming: true, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); // Clear the auto function invocation options. - ClearOptionsForAutoFunctionInvocation(ref options!); + ClearOptionsForAutoFunctionInvocation(ref options); // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages // include all activities, including generated function results. 
@@ -408,7 +408,7 @@ private static void FixupHistories( { // In the very rare case where the inner client returned a response with a thread ID but then // returned a subsequent response without one, we want to reconstitute the full history. To do that, - // we can populate the history with the original chat messages and then all of the response + // we can populate the history with the original chat messages and then all the response // messages up until this point, which includes the most recent ones. augmentedHistory ??= []; augmentedHistory.Clear(); @@ -667,16 +667,16 @@ private async Task ProcessFunctionCallAsync( callContent.Arguments = new KernelArguments(callContent.Arguments); } - var context = new AutoFunctionInvocationContext(options) + var context = new AutoFunctionInvocationContext(new KernelFunctionInvocationContext { + Options = options, Messages = messages, CallContent = callContent, - AIFunction = function, + Function = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, FunctionCount = callContents.Count, - IsStreaming = isStreaming - }; + }) { IsStreaming = isStreaming }; object? result; try @@ -721,9 +721,9 @@ internal IList CreateResponseMessages( IReadOnlyList results) { var contents = new List(results.Count); - for (int i = 0; i < results.Count; i++) + foreach (var t in results) { - contents.Add(CreateFunctionResultContent(results[i])); + contents.Add(CreateFunctionResultContent(t)); } return [new(ChatRole.Tool, contents)]; @@ -778,20 +778,20 @@ private async Task OnAutoFunctionInvocationAsync( /// If there are no registered filters, just function will be executed. /// If there are registered filters, filter on position will be executed. /// Second parameter of filter is callback. It can be either filter on + 1 position or function if there are no remaining filters to execute. - /// Function will be always executed as last step after all filters. + /// Function will always be executed as last step after all filters. 
/// private async Task InvokeFilterOrFunctionAsync( Func functionCallCallback, AutoFunctionInvocationContext context, int index = 0) { - IList? autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters; + IList autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters; if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count) { await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( context, - (context) => this.InvokeFilterOrFunctionAsync(functionCallCallback, context, index + 1) + (ctx) => this.InvokeFilterOrFunctionAsync(functionCallCallback, ctx, index + 1) ).ConfigureAwait(false); } else @@ -805,11 +805,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . /// is . - internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, CancellationToken cancellationToken) + private async Task InvokeFunctionAsync(AutoFunctionInvocationContext invocationContext, CancellationToken cancellationToken) { Verify.NotNull(invocationContext); - using Activity? activity = this._activitySource?.StartActivity(invocationContext.AIFunction.Name); + using Activity? 
activity = this._activitySource?.StartActivity(invocationContext.Function.Name); long startingTimestamp = 0; if (this._logger.IsEnabled(LogLevel.Debug)) @@ -817,11 +817,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( startingTimestamp = Stopwatch.GetTimestamp(); if (this._logger.IsEnabled(LogLevel.Trace)) { - this.LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions)); + this.LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions)); } else { - this.LogInvoking(invocationContext.AIFunction.Name); + this.LogInvoking(invocationContext.Function.Name); } } @@ -829,26 +829,23 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( try { CurrentContext = invocationContext; - if (invocationContext is AutoFunctionInvocationContext autoFunctionInvocationContext) - { - autoFunctionInvocationContext = await this.OnAutoFunctionInvocationAsync( - autoFunctionInvocationContext, - async (context) => + invocationContext = await this.OnAutoFunctionInvocationAsync( + invocationContext, + async (context) => + { + // Check if filter requested termination + if (context.Terminate) { - // Check if filter requested termination - if (context.Terminate) - { - return; - } - - // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any - // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, - // as the called function could in turn telling the model about itself as a possible candidate for invocation. 
- result = await invocationContext.AIFunction.InvokeAsync(new(autoFunctionInvocationContext.Arguments), cancellationToken).ConfigureAwait(false); - context.Result = new FunctionResult(context.Function, result); - }).ConfigureAwait(false); - result = autoFunctionInvocationContext.Result.GetValue(); - } + return; + } + + // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any + // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, + // as the called function could in turn telling the model about itself as a possible candidate for invocation. + result = await invocationContext.AIFunction.InvokeAsync(new(invocationContext.Arguments), cancellationToken).ConfigureAwait(false); + context.Result = new FunctionResult(context.Function, result); + }).ConfigureAwait(false); + result = invocationContext.Result.GetValue(); } catch (Exception e) { @@ -860,11 +857,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (e is OperationCanceledException) { - this.LogInvocationCanceled(invocationContext.AIFunction.Name); + this.LogInvocationCanceled(invocationContext.Function.Name); } else { - this.LogInvocationFailed(invocationContext.AIFunction.Name, e); + this.LogInvocationFailed(invocationContext.Function.Name, e); } throw; @@ -877,11 +874,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (result is not null && this._logger.IsEnabled(LogLevel.Trace)) { - this.LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions)); + this.LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions)); } else { - this.LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed); + 
this.LogInvocationCompleted(invocationContext.Function.Name, elapsed); } } } diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index d46e006e7ebb..f70205365330 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -13,30 +13,32 @@ namespace Microsoft.SemanticKernel; /// /// Class with data related to automatic function invocation. /// -public class AutoFunctionInvocationContext : KernelFunctionInvocationContext +public class AutoFunctionInvocationContext { private ChatHistory? _chatHistory; private KernelFunction? _kernelFunction; + private readonly KernelFunctionInvocationContext _invocationContext = new(); /// /// Initializes a new instance of the class from an existing . /// - public AutoFunctionInvocationContext(ChatOptions options) + internal AutoFunctionInvocationContext(KernelFunctionInvocationContext invocationContext) { - this.Options = options; + Verify.NotNull(invocationContext); + Verify.NotNull(invocationContext.Options); - // To create a AutoFunctionInvocationContext from a KernelFunctionInvocationContext, // the ChatOptions must be provided with AdditionalProperties. 
- Verify.NotNull(options.AdditionalProperties); + Verify.NotNull(invocationContext.Options.AdditionalProperties); - options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); + invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel); Verify.NotNull(kernel); - options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); + invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent); Verify.NotNull(chatMessageContent); - options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); + invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings); this.ExecutionSettings = executionSettings; + this._invocationContext = invocationContext; this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture }; } @@ -62,7 +64,7 @@ public AutoFunctionInvocationContext( Verify.NotNull(chatHistory); Verify.NotNull(chatMessageContent); - this.Options = new() + this._invocationContext.Options = new() { AdditionalProperties = new() { @@ -73,9 +75,9 @@ public AutoFunctionInvocationContext( this._kernelFunction = function; this._chatHistory = chatHistory; - this.Messages = chatHistory.ToChatMessageList(); - chatHistory.SetChatMessageHandlers(this.Messages); - this.AIFunction = function.AsAIFunction(); + this._invocationContext.Messages = chatHistory.ToChatMessageList(); + chatHistory.SetChatMessageHandlers(this._invocationContext.Messages); + this._invocationContext.Function = function.AsAIFunction(); this.Result = result; } @@ -95,8 +97,8 @@ public AutoFunctionInvocationContext( /// public KernelArguments? Arguments { - get => this.CallContent.Arguments is KernelArguments kernelArguments ? 
kernelArguments : null; - init => this.CallContent.Arguments = value; + get => this._invocationContext.CallContent.Arguments is KernelArguments kernelArguments ? kernelArguments : null; + init => this._invocationContext.CallContent.Arguments = value; } /// @@ -104,8 +106,8 @@ public KernelArguments? Arguments /// public int RequestSequenceIndex { - get => this.Iteration; - init => this.Iteration = value; + get => this._invocationContext.Iteration; + init => this._invocationContext.Iteration = value; } /// @@ -113,8 +115,19 @@ public int RequestSequenceIndex /// public int FunctionSequenceIndex { - get => this.FunctionCallIndex; - init => this.FunctionCallIndex = value; + get => this._invocationContext.FunctionCallIndex; + init => this._invocationContext.FunctionCallIndex = value; + } + + /// Gets or sets the total number of function call requests within the iteration. + /// + /// The response from the underlying client might include multiple function call requests. + /// This count indicates how many there were. + /// + public int FunctionCount + { + get => this._invocationContext.FunctionCount; + init => this._invocationContext.FunctionCount = value; } /// @@ -122,36 +135,40 @@ public int FunctionSequenceIndex /// public string? ToolCallId { - get => this.CallContent.CallId; + get => this._invocationContext.CallContent.CallId; init { - this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value ?? string.Empty, this.CallContent.Name, this.CallContent.Arguments); + this._invocationContext.CallContent = new Microsoft.Extensions.AI.FunctionCallContent( + callId: value ?? string.Empty, + name: this._invocationContext.CallContent.Name, + arguments: this._invocationContext.CallContent.Arguments); } } /// /// The chat message content associated with automatic function invocation. 
/// - public ChatMessageContent ChatMessageContent => (this.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!; + public ChatMessageContent ChatMessageContent + => (this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!; /// /// The execution settings associated with the operation. /// public PromptExecutionSettings? ExecutionSettings { - get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; + get => this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings; init { - this.Options ??= new(); - this.Options.AdditionalProperties ??= []; - this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; + this._invocationContext.Options ??= new(); + this._invocationContext.Options.AdditionalProperties ??= []; + this._invocationContext.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value; } } /// /// Gets the associated with automatic function invocation. /// - public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this.Messages); + public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this._invocationContext.Messages); /// /// Gets the with which this filter is associated. @@ -163,9 +180,9 @@ public KernelFunction Function if (this._kernelFunction is null // If the schemas are different, // AIFunction reference potentially was modified and the kernel function should be regenerated. 
- || !IsSameSchema(this._kernelFunction, this.AIFunction)) + || !IsSameSchema(this._kernelFunction, this._invocationContext.Function)) { - this._kernelFunction = this.AIFunction.AsKernelFunction(); + this._kernelFunction = this._invocationContext.Function.AsKernelFunction(); } return this._kernelFunction; @@ -180,7 +197,7 @@ public Kernel Kernel get { Kernel? kernel = null; - this.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel); + this._invocationContext.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel); // To avoid exception from properties, when attempting to retrieve a kernel from a non-ready context, it will give a null. return kernel!; @@ -192,6 +209,32 @@ public Kernel Kernel /// public FunctionResult Result { get; set; } + /// Gets or sets a value indicating whether to terminate the request. + /// + /// In response to a function call request, the function might be invoked, its result added to the chat contents, + /// and a new request issued to the wrapped client. If this property is set to , that subsequent request + /// will not be issued and instead the loop immediately terminated rather than continuing until there are no + /// more function call requests in responses. + /// + public bool Terminate + { + get => this._invocationContext.Terminate; + set => this._invocationContext.Terminate = value; + } + + /// Gets or sets the function call content information associated with this invocation. + internal Microsoft.Extensions.AI.FunctionCallContent CallContent + { + get => this._invocationContext.CallContent; + set => this._invocationContext.CallContent = value; + } + + internal AIFunction AIFunction + { + get => this._invocationContext.Function; + set => this._invocationContext.Function = value; + } + private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction) { // Compares the schemas, should be similar. 
@@ -305,11 +348,10 @@ IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); } - /// Destructor to clear the overrides and update the message. + /// Destructor to clear the chat history overrides. ~AutoFunctionInvocationContext() { - // The moment this class is destroyed, we need to clear the overrides and - // overrides to update to message + // The moment this class is destroyed, we need to clear the update message overrides this._chatHistory?.ClearOverrides(); } } From c3b89d825b3e6d31271d0a5ce514f1c8feba8966 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 14 Apr 2025 10:53:39 +0100 Subject: [PATCH 19/26] Starting update of FunctionInvokingChatClient --- dotnet/Directory.Packages.props | 3 +- .../AI/ChatClient/AIFunctionFactory.cs | 631 ------------- .../AI/ChatClient/FunctionFactoryOptions.cs | 63 -- .../KernelFunctionInvokingChatClient.cs | 4 +- .../KernelFunctionInvokingChatClientV2.cs | 873 ++++++++++++++++++ .../SemanticKernel.Abstractions.csproj | 2 +- 6 files changed, 877 insertions(+), 699 deletions(-) delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs create mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 7878b0d5b359..94e6c8de85ef 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -93,7 +93,6 @@ - @@ -109,7 +108,7 @@ - + diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs deleted file mode 100644 index e7afdd77e153..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright 
(c) Microsoft. All rights reserved. - -using System; -using System.Buffers; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Linq; -using System.Reflection; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization.Metadata; -using System.Text.RegularExpressions; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; - -#pragma warning disable IDE0009 // Use explicit 'this.' qualifier -#pragma warning disable IDE1006 // Missing static prefix s_ suffix - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs - -/// Provides factory methods for creating commonly used implementations of . -[ExcludeFromCodeCoverage] -internal static partial class AIFunctionFactory -{ - /// Holds the default options instance used when creating function. - private static readonly AIFunctionFactoryOptions _defaultOptions = new(); - - /// Creates an instance for a method, specified via a delegate. - /// The method to be represented via the created . - /// Metadata to use to override defaults inferred from . - /// The created for invoking . - /// - /// - /// Return values are serialized to using 's - /// . Arguments that are not already of the expected type are - /// marshaled to the expected type via JSON and using 's - /// . If the argument is a , - /// , or , it is deserialized directly. If the argument is anything else unknown, - /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? 
options) - { - Verify.NotNull(method); - - return ReflectionAIFunction.Build(method.Method, method.Target, options ?? _defaultOptions); - } - - /// Creates an instance for a method, specified via a delegate. - /// The method to be represented via the created . - /// The name to use for the . - /// The description to use for the . - /// The used to marshal function parameters and any return value. - /// The created for invoking . - /// - /// - /// Return values are serialized to using . - /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using - /// . If the argument is a , , - /// or , it is deserialized directly. If the argument is anything else unknown, it is - /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(Delegate method, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) - { - Verify.NotNull(method); - - AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null - ? _defaultOptions - : new() - { - Name = name, - Description = description, - SerializerOptions = serializerOptions, - }; - - return ReflectionAIFunction.Build(method.Method, method.Target, createOptions); - } - - /// - /// Creates an instance for a method, specified via an instance - /// and an optional target object if the method is an instance method. - /// - /// The method to be represented via the created . - /// - /// The target object for the if it represents an instance method. - /// This should be if and only if is a static method. - /// - /// Metadata to use to override defaults inferred from . - /// The created for invoking . - /// - /// - /// Return values are serialized to using 's - /// . Arguments that are not already of the expected type are - /// marshaled to the expected type via JSON and using 's - /// . 
If the argument is a , - /// , or , it is deserialized directly. If the argument is anything else unknown, - /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options) - { - Verify.NotNull(method); - - return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions); - } - - /// - /// Creates an instance for a method, specified via an instance - /// and an optional target object if the method is an instance method. - /// - /// The method to be represented via the created . - /// - /// The target object for the if it represents an instance method. - /// This should be if and only if is a static method. - /// - /// The name to use for the . - /// The description to use for the . - /// The used to marshal function parameters and return value. - /// The created for invoking . - /// - /// - /// Return values are serialized to using . - /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using - /// . If the argument is a , , - /// or , it is deserialized directly. If the argument is anything else unknown, it is - /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type. - /// - /// - public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null) - { - Verify.NotNull(method); - - AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null - ? 
_defaultOptions - : new() - { - Name = name, - Description = description, - SerializerOptions = serializerOptions, - }; - - return ReflectionAIFunction.Build(method, target, createOptions); - } - - private sealed class ReflectionAIFunction : AIFunction - { - public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFunctionFactoryOptions options) - { - Verify.NotNull(method); - - if (method.ContainsGenericParameters) - { - throw new ArgumentException("Open generic methods are not supported", nameof(method)); - } - - if (!method.IsStatic && target is null) - { - throw new ArgumentNullException(nameof(target), "Target must not be null for an instance method."); - } - - ReflectionAIFunctionDescriptor functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options); - - if (target is null && options.AdditionalProperties is null) - { - // We can use a cached value for static methods not specifying additional properties. - return functionDescriptor.CachedDefaultInstance ??= new(functionDescriptor, target, options); - } - - return new(functionDescriptor, target, options); - } - - private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options) - { - FunctionDescriptor = functionDescriptor; - Target = target; - AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance; - } - - public ReflectionAIFunctionDescriptor FunctionDescriptor { get; } - public object? 
Target { get; } - public override IReadOnlyDictionary AdditionalProperties { get; } - public override string Name => FunctionDescriptor.Name; - public override string Description => FunctionDescriptor.Description; - public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method; - public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema; - public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions; - protected override ValueTask InvokeCoreAsync( - AIFunctionArguments? arguments, - CancellationToken cancellationToken) - { - var paramMarshallers = FunctionDescriptor.ParameterMarshallers; - object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : []; - - IReadOnlyDictionary argDict = - arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance : - arguments as IReadOnlyDictionary ?? - arguments. -#if NET8_0_OR_GREATER - ToDictionary(); -#else - ToDictionary(kvp => kvp.Key, kvp => kvp.Value); -#endif - for (int i = 0; i < args.Length; i++) - { - args[i] = paramMarshallers[i](argDict, cancellationToken); - } - - return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken); - } - } - - /// - /// A descriptor for a .NET method-backed AIFunction that precomputes its marshalling delegates and JSON schema. - /// - private sealed class ReflectionAIFunctionDescriptor - { - private const int InnerCacheSoftLimit = 512; - private static readonly ConditionalWeakTable> _descriptorCache = new(); - - /// A boxed . - private static readonly object? _boxedDefaultCancellationToken = default(CancellationToken); - - /// - /// Gets or creates a descriptors using the specified method and options. - /// - public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFunctionFactoryOptions options) - { - JsonSerializerOptions serializerOptions = options.SerializerOptions ?? 
AIJsonUtilities.DefaultOptions; - AIJsonSchemaCreateOptions schemaOptions = options.JsonSchemaCreateOptions ?? AIJsonSchemaCreateOptions.Default; - serializerOptions.MakeReadOnly(); - ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions); - - DescriptorKey key = new(method, options.Name, options.Description, schemaOptions); - if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor)) - { - return descriptor; - } - - descriptor = new(key, serializerOptions); - return innerCache.Count < InnerCacheSoftLimit - ? innerCache.GetOrAdd(key, descriptor) - : descriptor; - } - - private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions) - { - // Get marshaling delegates for parameters. - ParameterInfo[] parameters = key.Method.GetParameters(); - ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length]; - for (int i = 0; i < parameters.Length; i++) - { - ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]); - } - - // Get a marshaling delegate for the return value. - ReturnParameterMarshaller = GetReturnParameterMarshaller(key.Method, serializerOptions); - - Method = key.Method; - Name = key.Name ?? GetFunctionName(key.Method); - Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty; - JsonSerializerOptions = serializerOptions; - JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema( - key.Method, - Name, - Description, - serializerOptions, - key.SchemaOptions); - } - - public string Name { get; } - public string Description { get; } - public MethodInfo Method { get; } - public JsonSerializerOptions JsonSerializerOptions { get; } - public JsonElement JsonSchema { get; } - public Func, CancellationToken, object?>[] ParameterMarshallers { get; } - public Func> ReturnParameterMarshaller { get; } - public ReflectionAIFunction? 
CachedDefaultInstance { get; set; } - - private static string GetFunctionName(MethodInfo method) - { - // Get the function name to use. - string name = SanitizeMemberName(method.Name); - - const string AsyncSuffix = "Async"; - if (IsAsyncMethod(method) && - name.EndsWith(AsyncSuffix, StringComparison.Ordinal) && - name.Length > AsyncSuffix.Length) - { - name = name.Substring(0, name.Length - AsyncSuffix.Length); - } - - return name; - - static bool IsAsyncMethod(MethodInfo method) - { - Type t = method.ReturnType; - - if (t == typeof(Task) || t == typeof(ValueTask)) - { - return true; - } - - if (t.IsGenericType) - { - t = t.GetGenericTypeDefinition(); - if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) - { - return true; - } - } - - return false; - } - } - - /// - /// Gets a delegate for handling the marshaling of a parameter. - /// - private static Func, CancellationToken, object?> GetParameterMarshaller( - JsonSerializerOptions serializerOptions, - ParameterInfo parameter) - { - if (string.IsNullOrWhiteSpace(parameter.Name)) - { - throw new ArgumentException("Parameter is missing a name.", nameof(parameter)); - } - - // Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found. - Type parameterType = parameter.ParameterType; - JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType); - - // For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync. - if (parameterType == typeof(CancellationToken)) - { - return static (_, cancellationToken) => - cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing - cancellationToken; - } - - // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary. - return (arguments, _) => - { - // If the parameter has an argument specified in the dictionary, return that argument. 
- if (arguments.TryGetValue(parameter.Name, out object? value)) - { - return value switch - { - null => null, // Return as-is if null -- if the parameter is a struct this will be handled by MethodInfo.Invoke - _ when parameterType.IsInstanceOfType(value) => value, // Do nothing if value is assignable to parameter type - JsonElement element => JsonSerializer.Deserialize(element, typeInfo), - JsonDocument doc => JsonSerializer.Deserialize(doc, typeInfo), - JsonNode node => JsonSerializer.Deserialize(node, typeInfo), - _ => MarshallViaJsonRoundtrip(value), - }; - - object? MarshallViaJsonRoundtrip(object value) - { -#pragma warning disable CA1031 // Do not catch general exception types - try - { - string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType())); - return JsonSerializer.Deserialize(json, typeInfo); - } - catch - { - // Eat any exceptions and fall back to the original value to force a cast exception later on. - return value; - } -#pragma warning restore CA1031 - } - } - - // There was no argument for the parameter in the dictionary. - // Does it have a default value? - if (parameter.HasDefaultValue) - { - return parameter.DefaultValue; - } - - // Leave it empty. - return null; - }; - } - - /// - /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. 
- /// - private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions) - { - Type returnType = method.ReturnType; - JsonTypeInfo returnTypeInfo; - - // Void - if (returnType == typeof(void)) - { - return static (_, _) => new ValueTask((object?)null); - } - - // Task - if (returnType == typeof(Task)) - { - return async static (result, _) => - { - await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); - return null; - }; - } - - // ValueTask - if (returnType == typeof(ValueTask)) - { - return async static (result, _) => - { - await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); - return null; - }; - } - - if (returnType.IsGenericType) - { - // Task - if (returnType.GetGenericTypeDefinition() == typeof(Task<>)) - { - MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult); - returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType); - return async (taskObj, cancellationToken) => - { - await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false); - object? result = ReflectionInvoke(taskResultGetter, taskObj, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - }; - } - - // ValueTask - if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) - { - MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask); - MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult); - returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType); - return async (taskObj, cancellationToken) => - { - var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!; - await task.ConfigureAwait(false); - object? 
result = ReflectionInvoke(asTaskResultGetter, task, null); - return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - }; - } - } - - // For everything else, just serialize the result as-is. - returnTypeInfo = serializerOptions.GetTypeInfo(returnType); - return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken); - - static async ValueTask SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken) - { - if (returnTypeInfo.Kind is JsonTypeInfoKind.None) - { - // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. - return JsonSerializer.SerializeToElement(result, returnTypeInfo); - } - - // Serialize asynchronously to support potential IAsyncEnumerable responses. - using PooledMemoryStream stream = new(); -#if NET9_0_OR_GREATER - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - Utf8JsonReader reader = new(stream.GetBuffer()); - return JsonElement.ParseValue(ref reader); -#else - await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false); - stream.Position = 0; - var serializerOptions = _defaultOptions.SerializerOptions ?? AIJsonUtilities.DefaultOptions; - return await JsonSerializer.DeserializeAsync(stream, serializerOptions.GetTypeInfo(typeof(JsonElement)), cancellationToken).ConfigureAwait(false); -#endif - } - - // Throws an exception if a result is found to be null unexpectedly - static object ThrowIfNullResult(object? result) => result ?? 
throw new InvalidOperationException("Function returned null unexpectedly."); - } - - private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!; - private static readonly MethodInfo _valueTaskAsTask = typeof(ValueTask<>).GetMethod(nameof(ValueTask.AsTask), BindingFlags.Instance | BindingFlags.Public)!; - - private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedType, MethodInfo genericMethodDefinition) - { - Debug.Assert(specializedType.IsGenericType && specializedType.GetGenericTypeDefinition() == genericMethodDefinition.DeclaringType, "generic member definition doesn't match type."); -#if NET - return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition); -#else -#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields - const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance; -#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields - return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken); -#endif - } - - private record struct DescriptorKey(MethodInfo Method, string? Name, string? Description, AIJsonSchemaCreateOptions SchemaOptions); - } - - /// - /// Removes characters from a .NET member name that shouldn't be used in an AI function name. - /// - /// The .NET member name that should be sanitized. - /// - /// Replaces non-alphanumeric characters in the identifier with the underscore character. - /// Primarily intended to remove characters produced by compiler-generated method name mangling. 
- /// - internal static string SanitizeMemberName(string memberName) - { - Verify.NotNull(memberName); - return InvalidNameCharsRegex().Replace(memberName, "_"); - } - - /// Regex that flags any character other than ASCII digits or letters or the underscore. -#if NET - [GeneratedRegex("[^0-9A-Za-z_]")] - private static partial Regex InvalidNameCharsRegex(); -#else - private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; - private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); -#endif - - /// Invokes the MethodInfo with the specified target object and arguments. - private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) - { -#if NET - return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); -#else - try - { - return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); - } - catch (TargetInvocationException e) when (e.InnerException is not null) - { - // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions - // is ignored, the original exception will be wrapped in a TargetInvocationException. - // Unwrap it and throw that original exception, maintaining its stack information. - System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); - throw; - } -#endif - } - - /// - /// Implements a simple write-only memory stream that uses pooled buffers. 
- /// - private sealed class PooledMemoryStream : Stream - { - private const int DefaultBufferSize = 4096; - private byte[] _buffer; - private int _position; - - public PooledMemoryStream(int initialCapacity = DefaultBufferSize) - { - _buffer = ArrayPool.Shared.Rent(initialCapacity); - _position = 0; - } - - public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position); - public override bool CanWrite => true; - public override bool CanRead => false; - public override bool CanSeek => false; - public override long Length => _position; - public override long Position - { - get => _position; - set => throw new NotSupportedException(); - } - - public override void Write(byte[] buffer, int offset, int count) - { - EnsureNotDisposed(); - EnsureCapacity(_position + count); - - Buffer.BlockCopy(buffer, offset, _buffer, _position, count); - _position += count; - } - - public override void Flush() - { - } - - public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); - - protected override void Dispose(bool disposing) - { - if (_buffer is not null) - { - ArrayPool.Shared.Return(_buffer); - _buffer = null!; - } - - base.Dispose(disposing); - } - - private void EnsureCapacity(int requiredCapacity) - { - if (requiredCapacity <= _buffer.Length) - { - return; - } - - int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2); - byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); - Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position); - - ArrayPool.Shared.Return(_buffer); - _buffer = newBuffer; - } - - private void EnsureNotDisposed() - { - if (_buffer is null) - { - Throw(); - static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream)); - } - } - } -} diff --git 
a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs deleted file mode 100644 index f9f43ee630ae..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs - -/// -/// Represents options that can be provided when creating an from a method. -/// -[ExcludeFromCodeCoverage] -internal sealed class AIFunctionFactoryOptions -{ - /// - /// Initializes a new instance of the class. - /// - public AIFunctionFactoryOptions() - { - } - - /// Gets or sets the used to marshal .NET values being passed to the underlying delegate. - /// - /// If no value has been specified, the instance will be used. - /// - public JsonSerializerOptions? SerializerOptions { get; set; } - - /// - /// Gets or sets the governing the generation of JSON schemas for the function. - /// - /// - /// If no value has been specified, the instance will be used. - /// - public AIJsonSchemaCreateOptions? JsonSchemaCreateOptions { get; set; } - - /// Gets or sets the name to use for the function. - /// - /// The name to use for the function. The default value is a name derived from the method represented by the passed or . - /// - public string? Name { get; set; } - - /// Gets or sets the description to use for the function. - /// - /// The description for the function. 
The default value is a description derived from the passed or , if possible - /// (for example, via a on the method). - /// - public string? Description { get; set; } - - /// - /// Gets or sets additional values to store on the resulting property. - /// - /// - /// This property can be used to provide arbitrary information about the function. - /// - public IReadOnlyDictionary? AdditionalProperties { get; set; } -} diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 3a5eb88b5cf3..df7b9f107fa2 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -45,7 +45,7 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// /// [ExcludeFromCodeCoverage] -internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient +internal sealed partial class KernelFunctionInvokingChatClientOld : DelegatingChatClient { /// The for the current function invocation. private static readonly AsyncLocal s_currentContext = new(); @@ -65,7 +65,7 @@ internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatC /// /// The underlying , or the next instance in a chain of clients. /// An to use for logging information about function invocation. - public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null) + public KernelFunctionInvokingChatClientOld(IChatClient innerClient, ILogger? logger = null) : base(innerClient) { this._logger = logger ?? 
NullLogger.Instance; diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs new file mode 100644 index 000000000000..6978a01dd44b --- /dev/null +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs @@ -0,0 +1,873 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable CA2213 // Disposable fields should be disposed +#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test +#pragma warning disable SA1202 // 'protected' members should come before 'private' members + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that invokes functions defined on . +/// Include this in a chat pipeline to resolve function calls automatically. +/// +/// +/// +/// When this client receives a in a chat response, it responds +/// by calling the corresponding defined in , +/// producing a that it sends back to the inner client. This loop +/// is repeated until there are no more function calls to make, or until another stop condition is met, +/// such as hitting . +/// +/// +/// The provided implementation of is thread-safe for concurrent use so long as the +/// instances employed as part of the supplied are also safe. 
+/// The property can be used to control whether multiple function invocation +/// requests as part of the same request are invocable concurrently, but even with that set to +/// (the default), multiple concurrent requests to this same instance and using the same tools could result in those +/// tools being used concurrently (one per request). For example, a function that accesses the HttpContext of a specific +/// ASP.NET web request should only be used as part of a single at a time, and only with +/// set to , in case the inner client decided to issue multiple +/// invocation requests to that same function. +/// +/// +public partial class FunctionInvokingChatClient : DelegatingChatClient +{ + /// The for the current function invocation. + private static readonly AsyncLocal _currentContext = new(); + + /// Optional services used for function invocation. + private readonly IServiceProvider? _functionInvocationServices; + + /// The logger to use for logging information about function invocation. + private readonly ILogger _logger; + + /// The to use for telemetry. + /// This component does not own the instance and should not dispose it. + private readonly ActivitySource? _activitySource; + + /// Maximum number of roundtrips allowed to the inner client. + private int _maximumIterationsPerRequest = 10; + + /// Maximum number of consecutive iterations that are allowed contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing. + private int _maximumConsecutiveErrorsPerRequest = 3; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , or the next instance in a chain of clients. + /// An to use for logging information about function invocation. + /// An optional to use for resolving services required by the instances being invoked. + public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? 
functionInvocationServices = null) + : base(innerClient) + { + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _activitySource = innerClient.GetService(); + _functionInvocationServices = functionInvocationServices; + } + + /// + /// Gets or sets the for the current function invocation. + /// + /// + /// This value flows across async calls. + /// + public static FunctionInvocationContext? CurrentContext + { + get => _currentContext.Value; + protected set => _currentContext.Value = value; + } + + /// + /// Gets or sets a value indicating whether detailed exception information should be included + /// in the chat history when calling the underlying . + /// + /// + /// if the full exception message is added to the chat history + /// when calling the underlying . + /// if a generic error message is included in the chat history. + /// The default value is . + /// + /// + /// + /// Setting the value to prevents the underlying language model from disclosing + /// raw exception details to the end user, since it doesn't receive that information. Even in this + /// case, the raw object is available to application code by inspecting + /// the property. + /// + /// + /// Setting the value to can help the underlying bypass problems on + /// its own, for example by retrying the function call with different arguments. However it might + /// result in disclosing the raw exception information to external users, which can be a security + /// concern depending on the application scenario. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to whether detailed errors are provided during an in-flight request. + /// + /// + public bool IncludeDetailedErrors { get; set; } + + /// + /// Gets or sets a value indicating whether to allow concurrent invocation of functions. + /// + /// + /// if multiple function calls can execute in parallel. + /// if function calls are processed serially. 
+ /// The default value is . + /// + /// + /// An individual response from the inner client might contain multiple function call requests. + /// By default, such function calls are processed serially. Set to + /// to enable concurrent invocation such that multiple function calls can execute in parallel. + /// + public bool AllowConcurrentInvocation { get; set; } + + /// + /// Gets or sets the maximum number of iterations per request. + /// + /// + /// The maximum number of iterations per request. + /// The default value is 10. + /// + /// + /// + /// Each request to this might end up making + /// multiple requests to the inner client. Each time the inner client responds with + /// a function call request, this client might perform that invocation and send the results + /// back to the inner client in a new request. This property limits the number of times + /// such a roundtrip is performed. The value must be at least one, as it includes the initial request. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to how many iterations are allowed for an in-flight request. + /// + /// + public int MaximumIterationsPerRequest + { + get => _maximumIterationsPerRequest; + set + { + if (value < 1) + { + Throw.ArgumentOutOfRangeException(nameof(value)); + } + + _maximumIterationsPerRequest = value; + } + } + + /// + /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error. + /// + /// + /// The maximum number of consecutive iterations that are allowed to fail with an error. + /// The default value is 3. + /// + /// + /// + /// When function invocations fail with an exception, the + /// continues to make requests to the inner client, optionally supplying exception information (as + /// controlled by ). This allows the to + /// recover from errors by trying other function parameters that may succeed. 
+ /// + /// + /// However, in case function invocations continue to produce exceptions, this property can be used to + /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be + /// rethrown to the caller. + /// + /// + /// If the value is set to zero, all function calling exceptions immediately terminate the function + /// invocation loop and the exception will be rethrown to the caller. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to how many iterations are allowed for an in-flight request. + /// + /// + public int MaximumConsecutiveErrorsPerRequest + { + get => _maximumConsecutiveErrorsPerRequest; + set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0); + } + + /// + public override async Task GetResponseAsync( + IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + // A single request into this GetResponseAsync may result in multiple requests to the inner client. + // Create an activity to group them together for better observability. + using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); + + // Copy the original messages in order to avoid enumerating the original messages multiple times. + // The IEnumerable can represent an arbitrary amount of work. + List originalMessages = [.. messages]; + messages = originalMessages; + + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned + List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response + UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response + List? 
functionCallContents = null; // function call contents that need responding to in the current turn + bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set + int consecutiveErrorCount = 0; + + for (int iteration = 0; ; iteration++) + { + functionCallContents?.Clear(); + + // Make the call to the inner client. + response = await base.GetResponseAsync(messages, options, cancellationToken); + if (response is null) + { + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); + } + + // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. + bool requiresFunctionInvocation = + options?.Tools is { Count: > 0 } && + iteration < MaximumIterationsPerRequest && + CopyFunctionCalls(response.Messages, ref functionCallContents); + + // In a common case where we make a request and there's no function calling work required, + // fast path out by just returning the original response. + if (iteration == 0 && !requiresFunctionInvocation) + { + return response; + } + + // Track aggregatable details from the response, including all of the response messages and usage details. + (responseMessages ??= []).AddRange(response.Messages); + if (response.Usage is not null) + { + if (totalUsage is not null) + { + totalUsage.Add(response.Usage); + } + else + { + totalUsage = response.Usage; + } + } + + // If there are no tools to call, or for any other reason we should stop, we're done. + // Break out of the loop and allow the handling at the end to configure the response + // with aggregated data from previous requests. + if (!requiresFunctionInvocation) + { + break; + } + + // Prepare the history for the next iteration. 
+ FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); + + // Add the responses from the function calls into the augmented history and also into the tracked + // list of response messages. + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; + + if (modeAndMessages.ShouldTerminate) + { + break; + } + + UpdateOptionsForNextIteration(ref options!, response.ChatThreadId); + } + + Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); + response.Messages = responseMessages!; + response.Usage = totalUsage; + + return response; + } + + /// + public override async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(messages); + + // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. + // Create an activity to group them together for better observability. + using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); + + // Copy the original messages in order to avoid enumerating the original messages multiple times. + // The IEnumerable can represent an arbitrary amount of work. + List originalMessages = [.. messages]; + messages = originalMessages; + + List? augmentedHistory = null; // the actual history of messages sent on turns other than the first + List? functionCallContents = null; // function call contents that need responding to in the current turn + List? 
responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history + bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set + List updates = []; // updates from the current response + int consecutiveErrorCount = 0; + + for (int iteration = 0; ; iteration++) + { + updates.Clear(); + functionCallContents?.Clear(); + + await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken)) + { + if (update is null) + { + Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); + } + + updates.Add(update); + + _ = CopyFunctionCalls(update.Contents, ref functionCallContents); + + yield return update; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (functionCallContents is not { Count: > 0 } || + options?.Tools is not { Count: > 0 } || + iteration >= _maximumIterationsPerRequest) + { + break; + } + + // Reconsistitue a response from the response updates. + var response = updates.ToChatResponse(); + (responseMessages ??= []).AddRange(response.Messages); + + // Prepare the history for the next iteration. + FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); + + // Process all of the functions, adding their results into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken); + responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; + + // This is a synthetic ID since we're generating the tool messages instead of getting them from + // the underlying provider. 
When emitting the streamed chunks, it's perfectly valid for us to + // use the same message ID for all of them within a given iteration, as this is a single logical + // message with multiple content items. We could also use different message IDs per tool content, + // but there's no benefit to doing so. + string toolResponseId = Guid.NewGuid().ToString("N"); + + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages + // includes all activitys, including generated function results. + foreach (var message in modeAndMessages.MessagesAdded) + { + var toolResultUpdate = new ChatResponseUpdate + { + AdditionalProperties = message.AdditionalProperties, + AuthorName = message.AuthorName, + ChatThreadId = response.ChatThreadId, + CreatedAt = DateTimeOffset.UtcNow, + Contents = message.Contents, + RawRepresentation = message.RawRepresentation, + ResponseId = toolResponseId, + MessageId = toolResponseId, // See above for why this can be the same as ResponseId + Role = message.Role, + }; + + yield return toolResultUpdate; + Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 + } + + if (modeAndMessages.ShouldTerminate) + { + yield break; + } + + UpdateOptionsForNextIteration(ref options, response.ChatThreadId); + } + } + + /// Prepares the various chat message lists after a response from the inner client and before invoking functions. + /// The original messages provided by the caller. + /// The messages reference passed to the inner client. + /// The augmented history containing all the messages to be sent. + /// The most recent response being handled. + /// A list of all response messages received up until this point. + /// Whether the previous iteration's response had a thread id. + private static void FixupHistories( + IEnumerable originalMessages, + ref IEnumerable messages, + [NotNull] ref List? 
augmentedHistory, + ChatResponse response, + List allTurnsResponseMessages, + ref bool lastIterationHadThreadId) + { + // We're now going to need to augment the history with function result contents. + // That means we need a separate list to store the augmented history. + if (response.ChatThreadId is not null) + { + // The response indicates the inner client is tracking the history, so we don't want to send + // anything we've already sent or received. + if (augmentedHistory is not null) + { + augmentedHistory.Clear(); + } + else + { + augmentedHistory = []; + } + + lastIterationHadThreadId = true; + } + else if (lastIterationHadThreadId) + { + // In the very rare case where the inner client returned a response with a thread ID but then + // returned a subsequent response without one, we want to reconstitue the full history. To do that, + // we can populate the history with the original chat messages and then all of the response + // messages up until this point, which includes the most recent ones. + augmentedHistory ??= []; + augmentedHistory.Clear(); + augmentedHistory.AddRange(originalMessages); + augmentedHistory.AddRange(allTurnsResponseMessages); + + lastIterationHadThreadId = false; + } + else + { + // If augmentedHistory is already non-null, then we've already populated it with everything up + // until this point (except for the most recent response). If it's null, we need to seed it with + // the chat history provided by the caller. + augmentedHistory ??= originalMessages.ToList(); + + // Now add the most recent response messages. + augmentedHistory.AddMessages(response); + + lastIterationHadThreadId = false; + } + + // Use the augmented history as the new set of messages to send. + messages = augmentedHistory; + } + + /// Copies any from to . + private static bool CopyFunctionCalls( + IList messages, [NotNullWhen(true)] ref List? 
functionCalls) + { + bool any = false; + int count = messages.Count; + for (int i = 0; i < count; i++) + { + any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls); + } + + return any; + } + + /// Copies any from to . + private static bool CopyFunctionCalls( + IList content, [NotNullWhen(true)] ref List? functionCalls) + { + bool any = false; + int count = content.Count; + for (int i = 0; i < count; i++) + { + if (content[i] is FunctionCallContent functionCall) + { + (functionCalls ??= []).Add(functionCall); + any = true; + } + } + + return any; + } + + private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId) + { + if (options.ToolMode is RequiredChatToolMode) + { + // We have to reset the tool mode to be non-required after the first iteration, + // as otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = null; + options.ChatThreadId = chatThreadId; + } + else if (options.ChatThreadId != chatThreadId) + { + // As with the other modes, ensure we've propagated the chat thread ID to the options. + // We only need to clone the options if we're actually mutating it. + options = options.Clone(); + options.ChatThreadId = chatThreadId; + } + } + + /// + /// Processes the function calls in the list. + /// + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call contents representing the functions to be invoked. + /// The iteration number of how many roundtrips have been made to the inner client. + /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. 
+ private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( + List messages, ChatOptions options, List functionCallContents, int iteration, int consecutiveErrorCount, CancellationToken cancellationToken) + { + // We must add a response for every tool call, regardless of whether we successfully executed it or not. + // If we successfully execute it, we'll add the result. If we don't, we'll add an error. + + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); + + var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; + + // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. + if (functionCallContents.Count == 1) + { + FunctionInvocationResult result = await ProcessFunctionCallAsync( + messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken); + + IList added = CreateResponseMessages([result]); + ThrowIfNoFunctionResultsAdded(added); + UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); + + messages.AddRange(added); + return (result.ShouldTerminate, consecutiveErrorCount, added); + } + else + { + FunctionInvocationResult[] results; + + if (AllowConcurrentInvocation) + { + // Rather than await'ing each function before invoking the next, invoke all of them + // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, + // but if a function invocation completes asynchronously, its processing can overlap + // with the processing of the other invocations. + results = await Task.WhenAll( + from i in Enumerable.Range(0, functionCallContents.Count) + select ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, i, captureExceptions: true, cancellationToken)); + } + else + { + // Invoke each function serially. 
+ results = new FunctionInvocationResult[functionCallContents.Count]; + for (int i = 0; i < results.Length; i++) + { + results[i] = await ProcessFunctionCallAsync( + messages, options, functionCallContents, + iteration, i, captureCurrentIterationExceptions, cancellationToken); + } + } + + var shouldTerminate = false; + + IList added = CreateResponseMessages(results); + ThrowIfNoFunctionResultsAdded(added); + UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); + + messages.AddRange(added); + foreach (FunctionInvocationResult fir in results) + { + shouldTerminate = shouldTerminate || fir.ShouldTerminate; + } + + return (shouldTerminate, consecutiveErrorCount, added); + } + } + + private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) + { + var allExceptions = added.SelectMany(m => m.Contents.OfType()) + .Select(frc => frc.Exception!) + .Where(e => e is not null); + + if (allExceptions.Any()) + { + consecutiveErrorCount++; + if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest) + { + var allExceptionsArray = allExceptions.ToArray(); + if (allExceptionsArray.Length == 1) + { + ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); + } + else + { + throw new AggregateException(allExceptionsArray); + } + } + } + else + { + consecutiveErrorCount = 0; + } + } + + /// + /// Throws an exception if doesn't create any messages. + /// + private void ThrowIfNoFunctionResultsAdded(IList? messages) + { + if (messages is null || messages.Count == 0) + { + Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); + } + } + + /// Processes the function call described in []. + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call contents representing all the functions being invoked. 
+ /// The iteration number of how many roundtrips have been made to the inner client. + /// The 0-based index of the function being called out of . + /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. + private async Task ProcessFunctionCallAsync( + List messages, ChatOptions options, List callContents, + int iteration, int functionCallIndex, bool captureExceptions, CancellationToken cancellationToken) + { + var callContent = callContents[functionCallIndex]; + + // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. + AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); + if (function is null) + { + return new(shouldTerminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); + } + + FunctionInvocationContext context = new() + { + Function = function, + Arguments = new(callContent.Arguments) { Services = _functionInvocationServices }, + + Messages = messages, + Options = options, + + CallContent = callContent, + Iteration = iteration, + FunctionCallIndex = functionCallIndex, + FunctionCount = callContents.Count, + }; + + object? result; + try + { + result = await InvokeFunctionAsync(context, cancellationToken); + } + catch (Exception e) when (!cancellationToken.IsCancellationRequested) + { + if (!captureExceptions) + { + throw; + } + + return new( + shouldTerminate: false, + FunctionInvocationStatus.Exception, + callContent, + result: null, + exception: e); + } + + return new( + shouldTerminate: context.Terminate, + FunctionInvocationStatus.RanToCompletion, + callContent, + result, + exception: null); + } + + /// Creates one or more response messages for function invocation results. + /// Information about the function call invocations and results. 
+ /// A list of all chat messages created from . + protected virtual IList CreateResponseMessages( + ReadOnlySpan results) + { + var contents = new List(results.Length); + for (int i = 0; i < results.Length; i++) + { + contents.Add(CreateFunctionResultContent(results[i])); + } + + return [new(ChatRole.Tool, contents)]; + + FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) + { + _ = Throw.IfNull(result); + + object? functionResult; + if (result.Status == FunctionInvocationStatus.RanToCompletion) + { + functionResult = result.Result ?? "Success: Function completed."; + } + else + { + string message = result.Status switch + { + FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", + FunctionInvocationStatus.Exception => "Error: Function failed.", + _ => "Error: Unknown error.", + }; + + if (IncludeDetailedErrors && result.Exception is not null) + { + message = $"{message} Exception: {result.Exception.Message}"; + } + + functionResult = message; + } + + return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception }; + } + } + + /// Invokes the function asynchronously. + /// + /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// + /// The to monitor for cancellation requests. The default is . + /// The result of the function invocation, or if the function invocation returned . + /// is . + protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) + { + _ = Throw.IfNull(context); + + using Activity? 
activity = _activitySource?.StartActivity(context.Function.Name); + + long startingTimestamp = 0; + if (_logger.IsEnabled(LogLevel.Debug)) + { + startingTimestamp = Stopwatch.GetTimestamp(); + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions)); + } + else + { + LogInvoking(context.Function.Name); + } + } + + object? result = null; + try + { + CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit + result = await context.Function.InvokeAsync(context.Arguments, cancellationToken); + } + catch (Exception e) + { + if (activity is not null) + { + _ = activity.SetTag("error.type", e.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, e.Message); + } + + if (e is OperationCanceledException) + { + LogInvocationCanceled(context.Function.Name); + } + else + { + LogInvocationFailed(context.Function.Name, e); + } + + throw; + } + finally + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + TimeSpan elapsed = GetElapsedTime(startingTimestamp); + + if (result is not null && _logger.IsEnabled(LogLevel.Trace)) + { + LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.JsonSerializerOptions)); + } + else + { + LogInvocationCompleted(context.Function.Name, elapsed); + } + } + } + + return result; + } + + private static TimeSpan GetElapsedTime(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency))); +#endif + + [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)] + private partial void LogInvoking(string methodName); + + [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)] + private partial void 
LogInvokingSensitive(string methodName, string arguments); + + [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)] + private partial void LogInvocationCompleted(string methodName, TimeSpan duration); + + [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)] + private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result); + + [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")] + private partial void LogInvocationCanceled(string methodName); + + [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")] + private partial void LogInvocationFailed(string methodName, Exception error); + + /// Provides information about the invocation of a function call. + public sealed class FunctionInvocationResult + { + internal FunctionInvocationResult(bool shouldTerminate, FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception) + { + ShouldTerminate = shouldTerminate; + Status = status; + CallContent = callContent; + Result = result; + Exception = exception; + } + + /// Gets status about how the function invocation completed. + public FunctionInvocationStatus Status { get; } + + /// Gets the function call content information associated with this invocation. + public FunctionCallContent CallContent { get; } + + /// Gets the result of the function call. + public object? Result { get; } + + /// Gets any exception the function call threw. + public Exception? Exception { get; } + + /// Gets a value indicating whether the caller should terminate the processing loop. + internal bool ShouldTerminate { get; } + } + + /// Provides error codes for when errors occur as part of the function calling loop. + public enum FunctionInvocationStatus + { + /// The operation completed successfully. 
+ RanToCompletion, + + /// The requested function could not be found. + NotFound, + + /// The function call failed with an exception. + Exception, + } +} diff --git a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj index 47043cbe1df8..4826426418a6 100644 --- a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj +++ b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj @@ -29,7 +29,7 @@ - + From 3d56b9564712731846bc5949aa8408ccfae2c190 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:20:14 +0100 Subject: [PATCH 20/26] Using Extensions AI + Latest logic for FunctionINvokingChatClient --- ...IServiceCollectionExtensions.ChatClient.cs | 12 +- .../AI/ChatClient/ChatClientExtensions.cs | 6 +- .../KernelFunctionInvocationContext.cs | 79 +- .../KernelFunctionInvokingChatClient.cs | 496 +++++----- .../KernelFunctionInvokingChatClientV2.cs | 873 ------------------ 5 files changed, 331 insertions(+), 1135 deletions(-) delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs index 7fa002d806ea..2954e958936a 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs @@ -47,12 +47,12 @@ public static IServiceCollection AddOpenAIChatClient( IChatClient Factory(IServiceProvider serviceProvider, object? _) { - ILogger? 
logger = serviceProvider.GetService()?.CreateLogger(); + var loggerFactory = serviceProvider.GetService(); return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) .GetChatClient(modelId) .AsIChatClient() - .AsKernelFunctionInvokingChatClient(logger); + .AsKernelFunctionInvokingChatClient(loggerFactory); } services.AddKeyedSingleton(serviceId, (Func)Factory); @@ -77,12 +77,12 @@ public static IServiceCollection AddOpenAIChatClient(this IServiceCollection ser IChatClient Factory(IServiceProvider serviceProvider, object? _) { - ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + var loggerFactory = serviceProvider.GetService(); return (openAIClient ?? serviceProvider.GetRequiredService()) .GetChatClient(modelId) .AsIChatClient() - .AsKernelFunctionInvokingChatClient(logger); + .AsKernelFunctionInvokingChatClient(loggerFactory); } services.AddKeyedSingleton(serviceId, (Func)Factory); @@ -114,12 +114,12 @@ public static IServiceCollection AddOpenAIChatClient( IChatClient Factory(IServiceProvider serviceProvider, object? _) { - ILogger? logger = serviceProvider.GetService()?.CreateLogger(); + var loggerFactory = serviceProvider.GetService(); return new OpenAIClient(new ApiKeyCredential(apiKey ?? 
SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))) .GetChatClient(modelId) .AsIChatClient() - .AsKernelFunctionInvokingChatClient(logger); + .AsKernelFunctionInvokingChatClient(loggerFactory); } services.AddKeyedSingleton(serviceId, (Func)Factory); diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index 0ffdd5fec99d..223530beb1ba 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -92,16 +92,16 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl /// Creates a new that supports for function invocation with a . /// /// Target chat client service. - /// Optional logger to use for logging. + /// Optional logger to use for logging. /// Function invoking chat client. [Experimental("SKEXP0001")] - public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILogger? logger = null) + public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILoggerFactory? loggerFactory = null) { Verify.NotNull(client); return client is KernelFunctionInvokingChatClient kernelFunctionInvocationClient ? kernelFunctionInvocationClient - : new KernelFunctionInvokingChatClient(client, logger); + : new KernelFunctionInvokingChatClient(client, loggerFactory); } private static ChatOptions GetChatOptionsFromSettings(PromptExecutionSettings? executionSettings, Kernel? 
kernel) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs index c8f293b2bb0c..efc3d39ab48a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using Microsoft.Extensions.AI; #pragma warning disable IDE0009 // Use explicit 'this.' qualifier @@ -20,63 +21,61 @@ internal class KernelFunctionInvocationContext { /// /// A nop function used to allow to be non-nullable. Default instances of - /// start with this as the target function. + /// start with this as the target function. /// - private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext)); + private static readonly AIFunction _nopFunction = AIFunctionFactory.Create(() => { }, nameof(FunctionInvocationContext)); /// The chat contents associated with the operation that initiated this function call request. private IList _messages = Array.Empty(); /// The AI function to be invoked. - private AIFunction _function = s_nopFunction; + private AIFunction _function = _nopFunction; /// The function call content information associated with this invocation. - private Microsoft.Extensions.AI.FunctionCallContent _callContent = new(string.Empty, s_nopFunction.Name, EmptyReadOnlyDictionary.Instance); + private Microsoft.Extensions.AI.FunctionCallContent? _callContent; - /// Initializes a new instance of the class. - internal KernelFunctionInvocationContext() + /// The arguments used with the function. + private AIFunctionArguments? _arguments; + + /// Initializes a new instance of the class. 
+ public KernelFunctionInvocationContext() + { + } + + /// Gets or sets the AI function to be invoked. + public AIFunction Function + { + get => _function; + set => _function = Throw.IfNull(value); + } + + /// Gets or sets the arguments associated with this invocation. + public AIFunctionArguments Arguments { + get => _arguments ??= []; + set => _arguments = Throw.IfNull(value); } /// Gets or sets the function call content information associated with this invocation. public Microsoft.Extensions.AI.FunctionCallContent CallContent { - get => _callContent; - set - { - Verify.NotNull(value); - _callContent = value; - } + get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary.Instance); + set => _callContent = Throw.IfNull(value); } /// Gets or sets the chat contents associated with the operation that initiated this function call request. public IList Messages { get => _messages; - set - { - Verify.NotNull(value); - _messages = value; - } + set => _messages = Throw.IfNull(value); } /// Gets or sets the chat options associated with the operation that initiated this function call request. public ChatOptions? Options { get; set; } - /// Gets or sets the AI function to be invoked. - public AIFunction Function - { - get => _function; - set - { - Verify.NotNull(value); - _function = value; - } - } - /// Gets or sets the number of this iteration with the underlying client. /// - /// The initial request to the client that passes along the chat contents provided to the + /// The initial request to the client that passes along the chat contents provided to the /// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on. /// public int Iteration { get; set; } @@ -103,4 +102,26 @@ public AIFunction Function /// more function call requests in responses. /// public bool Terminate { get; set; } + + private static class Throw + { + /// + /// Throws an if the specified argument is . 
+ /// + /// Argument type to be checked for . + /// Object to be checked for . + /// The name of the parameter being checked. + /// The original value of . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument is null) + { + throw new ArgumentNullException(paramName); + } + + return argument; + } + } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index df7b9f107fa2..e6496cc63f4f 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -1,26 +1,31 @@ // Copyright (c) Microsoft. All rights reserved. using System; +#pragma warning restore IDE0073 // The file header does not match the required text using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +#pragma warning disable IDE1006 // Naming Styles +#pragma warning disable IDE0009 // This #pragma warning disable CA2213 // Disposable fields should be disposed -#pragma warning disable IDE0009 // Use explicit 'this.' 
qualifier -#pragma warning disable IDE1006 // Missing prefix: 's_' +#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test +#pragma warning disable SA1202 // 'protected' members should come before 'private' members -namespace Microsoft.SemanticKernel.ChatCompletion; +// Modified source from 2025-04-07 +// https://raw.githubusercontent.com/dotnet/extensions/84d09b794d994435568adcbb85a981143d4f15cb/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +namespace Microsoft.Extensions.AI; /// /// A delegating chat client that invokes functions defined on . @@ -28,9 +33,11 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// /// /// -/// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a . +/// When this client receives a in a chat response, it responds +/// by calling the corresponding defined in , +/// producing a that it sends back to the inner client. This loop +/// is repeated until there are no more function calls to make, or until another stop condition is met, +/// such as hitting . /// /// /// The provided implementation of is thread-safe for concurrent use so long as the @@ -44,11 +51,13 @@ namespace Microsoft.SemanticKernel.ChatCompletion; /// invocation requests to that same function. /// /// -[ExcludeFromCodeCoverage] -internal sealed partial class KernelFunctionInvokingChatClientOld : DelegatingChatClient +public partial class KernelFunctionInvokingChatClient : DelegatingChatClient { /// The for the current function invocation. - private static readonly AsyncLocal s_currentContext = new(); + private static readonly AsyncLocal _currentContext = new(); + + /// Optional services used for function invocation. 
+ private readonly IServiceProvider? _functionInvocationServices; /// The logger to use for logging information about function invocation. private readonly ILogger _logger; @@ -58,49 +67,37 @@ internal sealed partial class KernelFunctionInvokingChatClientOld : DelegatingCh private readonly ActivitySource? _activitySource; /// Maximum number of roundtrips allowed to the inner client. - private int? _maximumIterationsPerRequest; + private int _maximumIterationsPerRequest = 10; + + /// Maximum number of consecutive iterations that are allowed to contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing. + private int _maximumConsecutiveErrorsPerRequest = 3; /// /// Initializes a new instance of the class. /// /// The underlying , or the next instance in a chain of clients. - /// An to use for logging information about function invocation. - public KernelFunctionInvokingChatClientOld(IChatClient innerClient, ILogger? logger = null) + /// An to use for logging information about function invocation. + /// An optional to use for resolving services required by the instances being invoked. + public KernelFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) : base(innerClient) { - this._logger = logger ?? NullLogger.Instance; - this._activitySource = innerClient.GetService(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _activitySource = innerClient.GetService(); + _functionInvocationServices = functionInvocationServices; } /// - /// Gets or sets the for the current function invocation. + /// Gets or sets the for the current function invocation. /// /// /// This value flows across async calls. /// - internal static AutoFunctionInvocationContext? CurrentContext + public static AutoFunctionInvocationContext? 
CurrentContext { - get => s_currentContext.Value; - set => s_currentContext.Value = value; + get => _currentContext.Value; + protected set => _currentContext.Value = value; } - /// - /// Gets or sets a value indicating whether to handle exceptions that occur during function calls. - /// - /// - /// if the - /// underlying will be instructed to give a response without invoking - /// any further functions if a function call fails with an exception. - /// if the underlying is allowed - /// to continue attempting function calls until is reached. - /// The default value is . - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// whether errors are retried during an in-flight request. - /// - public bool RetryOnError { get; set; } - /// /// Gets or sets a value indicating whether detailed exception information should be included /// in the chat history when calling the underlying . @@ -116,17 +113,17 @@ internal static AutoFunctionInvocationContext? CurrentContext /// Setting the value to prevents the underlying language model from disclosing /// raw exception details to the end user, since it doesn't receive that information. Even in this /// case, the raw object is available to application code by inspecting - /// the property. + /// the property. /// /// /// Setting the value to can help the underlying bypass problems on - /// its own, for example by retrying the function call with different arguments. However, it might + /// its own, for example by retrying the function call with different arguments. However it might /// result in disclosing the raw exception information to external users, which can be a security /// concern depending on the application scenario. /// /// /// Changing the value of this property while the client is in use might result in inconsistencies - /// whether detailed errors are provided during an in-flight request. 
+ /// as to whether detailed errors are provided during an in-flight request. /// /// public bool IncludeDetailedErrors { get; set; } @@ -151,25 +148,24 @@ internal static AutoFunctionInvocationContext? CurrentContext /// /// /// The maximum number of iterations per request. - /// The default value is . + /// The default value is 10. /// /// /// - /// Each request to this might end up making + /// Each request to this might end up making /// multiple requests to the inner client. Each time the inner client responds with /// a function call request, this client might perform that invocation and send the results /// back to the inner client in a new request. This property limits the number of times - /// such a roundtrip is performed. If null, there is no limit applied. If set, the value - /// must be at least one, as it includes the initial request. + /// such a roundtrip is performed. The value must be at least one, as it includes the initial request. /// /// /// Changing the value of this property while the client is in use might result in inconsistencies /// as to how many iterations are allowed for an in-flight request. /// /// - public int? MaximumIterationsPerRequest + public int MaximumIterationsPerRequest { - get => this._maximumIterationsPerRequest; + get => _maximumIterationsPerRequest; set { if (value < 1) @@ -177,7 +173,48 @@ public int? MaximumIterationsPerRequest throw new ArgumentOutOfRangeException(nameof(value)); } - this._maximumIterationsPerRequest = value; + _maximumIterationsPerRequest = value; + } + } + + /// + /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error. + /// + /// + /// The maximum number of consecutive iterations that are allowed to fail with an error. + /// The default value is 3. + /// + /// + /// + /// When function invocations fail with an exception, the + /// continues to make requests to the inner client, optionally supplying exception information (as + /// controlled by ). 
This allows the to + /// recover from errors by trying other function parameters that may succeed. + /// + /// + /// However, in case function invocations continue to produce exceptions, this property can be used to + /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be + /// rethrown to the caller. + /// + /// + /// If the value is set to zero, all function calling exceptions immediately terminate the function + /// invocation loop and the exception will be rethrown to the caller. + /// + /// + /// Changing the value of this property while the client is in use might result in inconsistencies + /// as to how many iterations are allowed for an in-flight request. + /// + /// + public int MaximumConsecutiveErrorsPerRequest + { + get => _maximumConsecutiveErrorsPerRequest; + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "Argument less than minimum value 0"); + } + _maximumConsecutiveErrorsPerRequest = value; } } @@ -189,7 +226,7 @@ public override async Task GetResponseAsync( // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. - using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); + using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); // Copy the original messages in order to avoid enumerating the original messages multiple times. // The IEnumerable can represent an arbitrary amount of work. @@ -200,8 +237,9 @@ public override async Task GetResponseAsync( ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response UsageDetails? 
totalUsage = null; // tracked usage across all turns, to be used for the final response - List? functionCallContents = null; // function call contents that need responding to in the current turn + List? functionCallContents = null; // function call contents that need responding to in the current turn bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set + int consecutiveErrorCount = 0; for (int iteration = 0; ; iteration++) { @@ -217,7 +255,7 @@ public override async Task GetResponseAsync( // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. bool requiresFunctionInvocation = options?.Tools is { Count: > 0 } && - (!this.MaximumIterationsPerRequest.HasValue || iteration < this.MaximumIterationsPerRequest.GetValueOrDefault()) && + iteration < MaximumIterationsPerRequest && CopyFunctionCalls(response.Messages, ref functionCallContents); // In a common case where we make a request and there's no function calling work required, @@ -227,7 +265,7 @@ public override async Task GetResponseAsync( return response; } - // Track aggregate details from the response, including all the response messages and usage details. + // Track aggregatable details from the response, including all of the response messages and usage details. (responseMessages ??= []).AddRange(response.Messages); if (response.Usage is not null) { @@ -257,21 +295,23 @@ public override async Task GetResponseAsync( // Add the responses from the function calls into the augmented history and also into the tracked // list of response messages. 
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false); + var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - // Clear the auto function invocation options. - ClearOptionsForAutoFunctionInvocation(ref options); - - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) + if (modeAndMessages.ShouldTerminate) { - // Terminate break; } + + // Clear the auto function invocation options. + ClearOptionsForAutoFunctionInvocation(ref options); + + UpdateOptionsForNextIteration(ref options!, response.ChatThreadId); } Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); - response.Messages = responseMessages; + response.Messages = responseMessages!; response.Usage = totalUsage; return response; @@ -285,7 +325,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. - using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); + using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); // Copy the original messages in order to avoid enumerating the original messages multiple times. // The IEnumerable can represent an arbitrary amount of work. @@ -293,10 +333,11 @@ public override async IAsyncEnumerable GetStreamingResponseA messages = originalMessages; List? 
augmentedHistory = null; // the actual history of messages sent on turns other than the first - List? functionCallContents = null; // function call contents that need responding to in the current turn + List? functionCallContents = null; // function call contents that need responding to in the current turn List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set List updates = []; // updates from the current response + int consecutiveErrorCount = 0; for (int iteration = 0; ; iteration++) { @@ -321,12 +362,12 @@ public override async IAsyncEnumerable GetStreamingResponseA // If there are no tools to call, or for any other reason we should stop, return the response. if (functionCallContents is not { Count: > 0 } || options?.Tools is not { Count: > 0 } || - (this.MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + iteration >= _maximumIterationsPerRequest) { break; } - // Reconstitute a response from the response updates. + // Reconstitute a response from the response updates. var response = updates.ToChatResponse(); (responseMessages ??= []).AddRange(response.Messages); @@ -336,16 +377,23 @@ public override async IAsyncEnumerable GetStreamingResponseA // Prepare the options for the next auto function invocation iteration. UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true); - // Process all the functions, adding their results into the history. - var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, isStreaming: true, cancellationToken).ConfigureAwait(false); + // Process all of the functions, adding their results into the history. 
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); + consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; // Clear the auto function invocation options. ClearOptionsForAutoFunctionInvocation(ref options); - // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages - // include all activities, including generated function results. + // This is a synthetic ID since we're generating the tool messages instead of getting them from + // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to + // use the same message ID for all of them within a given iteration, as this is a single logical + // message with multiple content items. We could also use different message IDs per tool content, + // but there's no benefit to doing so. string toolResponseId = Guid.NewGuid().ToString("N"); + + // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages + // include all activities, including generated function results. 
foreach (var message in modeAndMessages.MessagesAdded) { var toolResultUpdate = new ChatResponseUpdate @@ -357,6 +405,7 @@ public override async IAsyncEnumerable GetStreamingResponseA Contents = message.Contents, RawRepresentation = message.RawRepresentation, ResponseId = toolResponseId, + MessageId = toolResponseId, // See above for why this can be the same as ResponseId Role = message.Role, }; @@ -364,11 +413,12 @@ public override async IAsyncEnumerable GetStreamingResponseA Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 } - if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId)) + if (modeAndMessages.ShouldTerminate) { - // Terminate yield break; } + + UpdateOptionsForNextIteration(ref options, response.ChatThreadId); } } @@ -407,8 +457,8 @@ private static void FixupHistories( else if (lastIterationHadThreadId) { // In the very rare case where the inner client returned a response with a thread ID but then - // returned a subsequent response without one, we want to reconstitute the full history. To do that, - // we can populate the history with the original chat messages and then all the response + // returned a subsequent response without one, we want to reconstitute the full history. To do that, + // we can populate the history with the original chat messages and then all of the response messages up until this point, which includes the most recent ones. augmentedHistory ??= []; augmentedHistory.Clear(); @@ -436,7 +486,7 @@ private static void FixupHistories( /// Copies any from to . private static bool CopyFunctionCalls( - IList messages, [NotNullWhen(true)] ref List? functionCalls) + IList messages, [NotNullWhen(true)] ref List? functionCalls) { bool any = false; int count = messages.Count; @@ -448,15 +498,15 @@ private static bool CopyFunctionCalls( return any; } - /// Copies any from to . + /// Copies any from to . 
private static bool CopyFunctionCalls( - IList content, [NotNullWhen(true)] ref List? functionCalls) + IList content, [NotNullWhen(true)] ref List? functionCalls) { bool any = false; int count = content.Count; for (int i = 0; i < count; i++) { - if (content[i] is Microsoft.Extensions.AI.FunctionCallContent functionCall) + if (content[i] is FunctionCallContent functionCall) { (functionCalls ??= []).Add(functionCall); any = true; @@ -497,47 +547,23 @@ private static void ClearOptionsForAutoFunctionInvocation(ref ChatOptions option } } - /// Updates for the response. - /// true if the function calling loop should terminate; otherwise, false. - private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId) + private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId) { - switch (mode) + if (options.ToolMode is RequiredChatToolMode) { - case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: - // We have to reset the tool mode to be non-required after the first iteration, - // as otherwise we'll be in an infinite loop. - options = options.Clone(); - options.ToolMode = null; - options.ChatThreadId = chatThreadId; - - break; - - case ContinueMode.AllowOneMoreRoundtrip: - // The LLM gets one further chance to answer, but cannot use tools. - options = options.Clone(); - options.Tools = null; - options.ToolMode = null; - options.ChatThreadId = chatThreadId; - - break; - - case ContinueMode.Terminate: - // Bail immediately. - return true; - - default: - // As with the other modes, ensure we've propagated the chat thread ID to the options. - // We only need to clone the options if we're actually mutating it. - if (options.ChatThreadId != chatThreadId) - { - options = options.Clone(); - options.ChatThreadId = chatThreadId; - } - - break; + // We have to reset the tool mode to be non-required after the first iteration, + // as otherwise we'll be in an infinite loop. 
+ options = options.Clone(); + options.ToolMode = null; + options.ChatThreadId = chatThreadId; + } + else if (options.ChatThreadId != chatThreadId) + { + // As with the other modes, ensure we've propagated the chat thread ID to the options. + // We only need to clone the options if we're actually mutating it. + options = options.Clone(); + options.ChatThreadId = chatThreadId; } - - return false; } /// @@ -547,85 +573,122 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti /// The options used for the response being processed. /// The function call contents representing the functions to be invoked. /// The iteration number of how many roundtrips have been made to the inner client. + /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. - private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, bool isStreaming, CancellationToken cancellationToken) + /// A value indicating how the caller should proceed. + private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( + List messages, ChatOptions options, List functionCallContents, + int iteration, int consecutiveErrorCount, bool isStreaming, CancellationToken cancellationToken) { // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. 
- Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); - ContinueMode continueMode = ContinueMode.Continue; + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); + var shouldTerminate = false; + + var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. if (functionCallContents.Count == 1) { - FunctionInvocationResult result = await this.ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, isStreaming, cancellationToken).ConfigureAwait(false); + FunctionInvocationResult result = await ProcessFunctionCallAsync( + messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false); - IList added = this.CreateResponseMessages([result]); - this.ThrowIfNoFunctionResultsAdded(added); + IList added = CreateResponseMessages([result]); + ThrowIfNoFunctionResultsAdded(added); + UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); messages.AddRange(added); - return (result.ContinueMode, added); + return (result.ShouldTerminate, consecutiveErrorCount, added); } else { List results = []; var terminationRequested = false; - if (this.AllowConcurrentInvocation) + if (AllowConcurrentInvocation) { - // Schedule the invocation of every function. + // Rather than await'ing each function before invoking the next, invoke all of them + // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, + // but if a function invocation completes asynchronously, its processing can overlap + // with the processing of the other invocations. 
results.AddRange(await Task.WhenAll( from i in Enumerable.Range(0, functionCallContents.Count) - select Task.Run(() => this.ProcessFunctionCallAsync( + select ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, isStreaming, cancellationToken))).ConfigureAwait(false)); + iteration, i, captureExceptions: true, isStreaming, cancellationToken)).ConfigureAwait(false)); - terminationRequested = results.Any(r => r.ContinueMode == ContinueMode.Terminate); + terminationRequested = results.Any(r => r.ShouldTerminate); } else { // Invoke each function serially. for (int i = 0; i < functionCallContents.Count; i++) { - var result = await this.ProcessFunctionCallAsync( + var result = await ProcessFunctionCallAsync( messages, options, functionCallContents, - iteration, i, isStreaming, cancellationToken).ConfigureAwait(false); + iteration, i, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false); results.Add(result); - if (result.ContinueMode == ContinueMode.Terminate) + if (result.ShouldTerminate) { - continueMode = ContinueMode.Terminate; + shouldTerminate = true; terminationRequested = true; break; } } } - IList added = this.CreateResponseMessages(results); - this.ThrowIfNoFunctionResultsAdded(added); + IList added = CreateResponseMessages(results); + ThrowIfNoFunctionResultsAdded(added); + UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); + messages.AddRange(added); if (!terminationRequested) { // If any function requested termination, we'll terminate. 
- continueMode = ContinueMode.Continue; + shouldTerminate = false; foreach (FunctionInvocationResult fir in results) { - if (fir.ContinueMode > continueMode) - { - continueMode = fir.ContinueMode; - } + shouldTerminate = shouldTerminate || fir.ShouldTerminate; } } - return (continueMode, added); + return (shouldTerminate, consecutiveErrorCount, added); + } + } + + private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) + { + var allExceptions = added.SelectMany(m => m.Contents.OfType()) + .Select(frc => frc.Exception!) + .Where(e => e is not null); + +#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection + if (allExceptions.Any()) + { + consecutiveErrorCount++; + if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest) + { + var allExceptionsArray = allExceptions.ToArray(); + if (allExceptionsArray.Length == 1) + { + ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); + } + else + { + throw new AggregateException(allExceptionsArray); + } + } } + else + { + consecutiveErrorCount = 0; + } +#pragma warning restore CA1851 // Possible multiple enumerations of 'IEnumerable' collection } /// @@ -645,21 +708,21 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages) /// The function call contents representing all the functions being invoked. /// The iteration number of how many roundtrips have been made to the inner client. /// The 0-based index of the function being called out of . + /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows. /// Whether the function calls are being processed in a streaming context. /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. + /// A value indicating how the caller should proceed. 
private async Task ProcessFunctionCallAsync( - List messages, ChatOptions options, List callContents, - int iteration, int functionCallIndex, bool isStreaming, CancellationToken cancellationToken) + List messages, ChatOptions options, List callContents, + int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken) { var callContent = callContents[functionCallIndex]; // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. - AIFunction? function = options.Tools!.OfType().FirstOrDefault( - t => t.Name == callContent.Name); + AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); if (function is null) { - return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); + return new(shouldTerminate: false, FunctionInvokingChatClient.FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); } if (callContent.Arguments is not null) @@ -667,73 +730,69 @@ private async Task ProcessFunctionCallAsync( callContent.Arguments = new KernelArguments(callContent.Arguments); } - var context = new AutoFunctionInvocationContext(new KernelFunctionInvocationContext + var context = new AutoFunctionInvocationContext(new() { - Options = options, + Function = function, + Arguments = new(callContent.Arguments) { Services = _functionInvocationServices }, + Messages = messages, + Options = options, + CallContent = callContent, - Function = function, Iteration = iteration, FunctionCallIndex = functionCallIndex, FunctionCount = callContents.Count, - }) { IsStreaming = isStreaming }; + }) + { IsStreaming = isStreaming }; object? 
result; try { - result = await this.InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); + result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); } catch (Exception e) when (!cancellationToken.IsCancellationRequested) { - return new(this.RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer. - FunctionInvocationStatus.Exception, + if (!captureExceptions) + { + throw; + } + + return new( + shouldTerminate: false, + FunctionInvokingChatClient.FunctionInvocationStatus.Exception, callContent, result: null, exception: e); } return new( - context.Terminate ? ContinueMode.Terminate : ContinueMode.Continue, - FunctionInvocationStatus.RanToCompletion, + shouldTerminate: context.Terminate, + FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion, callContent, result, exception: null); } - /// Represents the return value of , dictating how the loop should behave. - /// These values are ordered from least severe to most severe, and code explicitly depends on the ordering. - internal enum ContinueMode - { - /// Send back the responses and continue processing. - Continue = 0, - - /// Send back the response but without any tools. - AllowOneMoreRoundtrip = 1, - - /// Immediately exit the function calling loop. - Terminate = 2, - } - /// Creates one or more response messages for function invocation results. /// Information about the function call invocations and results. /// A list of all chat messages created from . 
- internal IList CreateResponseMessages( + protected virtual IList CreateResponseMessages( IReadOnlyList results) { var contents = new List(results.Count); - foreach (var t in results) + for (int i = 0; i < results.Count; i++) { - contents.Add(CreateFunctionResultContent(t)); + contents.Add(CreateFunctionResultContent(results[i])); } return [new(ChatRole.Tool, contents)]; - Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) + FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) { Verify.NotNull(result); object? functionResult; - if (result.Status == FunctionInvocationStatus.RanToCompletion) + if (result.Status == FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion) { functionResult = result.Result ?? "Success: Function completed."; } @@ -741,12 +800,12 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi { string message = result.Status switch { - FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", - FunctionInvocationStatus.Exception => "Error: Function failed.", + FunctionInvokingChatClient.FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", + FunctionInvokingChatClient.FunctionInvocationStatus.Exception => "Error: Function failed.", _ => "Error: Unknown error.", }; - if (this.IncludeDetailedErrors && result.Exception is not null) + if (IncludeDetailedErrors && result.Exception is not null) { message = $"{message} Exception: {result.Exception.Message}"; } @@ -754,7 +813,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi functionResult = message; } - return new Microsoft.Extensions.AI.FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception }; + return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception 
}; } } @@ -801,40 +860,42 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( } /// Invokes the function asynchronously. - /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// + /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . - /// is . - private async Task InvokeFunctionAsync(AutoFunctionInvocationContext invocationContext, CancellationToken cancellationToken) + /// is . + protected virtual async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken) { - Verify.NotNull(invocationContext); + Verify.NotNull(context); - using Activity? activity = this._activitySource?.StartActivity(invocationContext.Function.Name); + using Activity? activity = _activitySource?.StartActivity(context.Function.Name); long startingTimestamp = 0; - if (this._logger.IsEnabled(LogLevel.Debug)) + if (_logger.IsEnabled(LogLevel.Debug)) { startingTimestamp = Stopwatch.GetTimestamp(); - if (this._logger.IsEnabled(LogLevel.Trace)) + if (_logger.IsEnabled(LogLevel.Trace)) { - this.LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions)); + LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.AIFunction.JsonSerializerOptions)); } else { - this.LogInvoking(invocationContext.Function.Name); + LogInvoking(context.Function.Name); } } object? 
result = null; try { - CurrentContext = invocationContext; - invocationContext = await this.OnAutoFunctionInvocationAsync( - invocationContext, - async (context) => + CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit + context = await this.OnAutoFunctionInvocationAsync( + context, + async (ctx) => { // Check if filter requested termination - if (context.Terminate) + if (ctx.Terminate) { return; } @@ -842,10 +903,10 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here, // as the called function could in turn telling the model about itself as a possible candidate for invocation. - result = await invocationContext.AIFunction.InvokeAsync(new(invocationContext.Arguments), cancellationToken).ConfigureAwait(false); - context.Result = new FunctionResult(context.Function, result); + result = await context.AIFunction.InvokeAsync(new(context.Arguments), cancellationToken).ConfigureAwait(false); + ctx.Result = new FunctionResult(ctx.Function, result); }).ConfigureAwait(false); - result = invocationContext.Result.GetValue(); + result = context.Result.GetValue(); } catch (Exception e) { @@ -857,28 +918,28 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( if (e is OperationCanceledException) { - this.LogInvocationCanceled(invocationContext.Function.Name); + LogInvocationCanceled(context.Function.Name); } else { - this.LogInvocationFailed(invocationContext.Function.Name, e); + LogInvocationFailed(context.Function.Name, e); } throw; } finally { - if (this._logger.IsEnabled(LogLevel.Debug)) + if (_logger.IsEnabled(LogLevel.Debug)) { TimeSpan elapsed = GetElapsedTime(startingTimestamp); - if (result is not 
null && this._logger.IsEnabled(LogLevel.Trace)) + if (result is not null && _logger.IsEnabled(LogLevel.Trace)) { - this.LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions)); + LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.AIFunction.JsonSerializerOptions)); } else { - this.LogInvocationCompleted(invocationContext.Function.Name, elapsed); + LogInvocationCompleted(context.Function.Name, elapsed); } } } @@ -936,20 +997,20 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) => /// Provides information about the invocation of a function call. public sealed class FunctionInvocationResult { - internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationStatus status, Microsoft.Extensions.AI.FunctionCallContent callContent, object? result, Exception? exception) + internal FunctionInvocationResult(bool shouldTerminate, FunctionInvokingChatClient.FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception) { - this.ContinueMode = continueMode; - this.Status = status; - this.CallContent = callContent; - this.Result = result; - this.Exception = exception; + ShouldTerminate = shouldTerminate; + Status = status; + CallContent = callContent; + Result = result; + Exception = exception; } /// Gets status about how the function invocation completed. - public FunctionInvocationStatus Status { get; } + public FunctionInvokingChatClient.FunctionInvocationStatus Status { get; } /// Gets the function call content information associated with this invocation. - public Microsoft.Extensions.AI.FunctionCallContent CallContent { get; } + public FunctionCallContent CallContent { get; } /// Gets the result of the function call. public object? 
Result { get; } @@ -957,20 +1018,7 @@ internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationS /// Gets any exception the function call threw. public Exception? Exception { get; } - /// Gets an indication for how the caller should continue the processing loop. - internal ContinueMode ContinueMode { get; } - } - - /// Provides error codes for when errors occur as part of the function calling loop. - public enum FunctionInvocationStatus - { - /// The operation completed successfully. - RanToCompletion, - - /// The requested function could not be found. - NotFound, - - /// The function call failed with an exception. - Exception, + /// Gets a value indicating whether the caller should terminate the processing loop. + internal bool ShouldTerminate { get; } } } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs deleted file mode 100644 index 6978a01dd44b..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs +++ /dev/null @@ -1,873 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. 
- -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Linq; -using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Shared.Diagnostics; - -#pragma warning disable CA2213 // Disposable fields should be disposed -#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test -#pragma warning disable SA1202 // 'protected' members should come before 'private' members - -namespace Microsoft.Extensions.AI; - -/// -/// A delegating chat client that invokes functions defined on . -/// Include this in a chat pipeline to resolve function calls automatically. -/// -/// -/// -/// When this client receives a in a chat response, it responds -/// by calling the corresponding defined in , -/// producing a that it sends back to the inner client. This loop -/// is repeated until there are no more function calls to make, or until another stop condition is met, -/// such as hitting . -/// -/// -/// The provided implementation of is thread-safe for concurrent use so long as the -/// instances employed as part of the supplied are also safe. -/// The property can be used to control whether multiple function invocation -/// requests as part of the same request are invocable concurrently, but even with that set to -/// (the default), multiple concurrent requests to this same instance and using the same tools could result in those -/// tools being used concurrently (one per request). For example, a function that accesses the HttpContext of a specific -/// ASP.NET web request should only be used as part of a single at a time, and only with -/// set to , in case the inner client decided to issue multiple -/// invocation requests to that same function. 
-/// -/// -public partial class FunctionInvokingChatClient : DelegatingChatClient -{ - /// The for the current function invocation. - private static readonly AsyncLocal _currentContext = new(); - - /// Optional services used for function invocation. - private readonly IServiceProvider? _functionInvocationServices; - - /// The logger to use for logging information about function invocation. - private readonly ILogger _logger; - - /// The to use for telemetry. - /// This component does not own the instance and should not dispose it. - private readonly ActivitySource? _activitySource; - - /// Maximum number of roundtrips allowed to the inner client. - private int _maximumIterationsPerRequest = 10; - - /// Maximum number of consecutive iterations that are allowed contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing. - private int _maximumConsecutiveErrorsPerRequest = 3; - - /// - /// Initializes a new instance of the class. - /// - /// The underlying , or the next instance in a chain of clients. - /// An to use for logging information about function invocation. - /// An optional to use for resolving services required by the instances being invoked. - public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null) - : base(innerClient) - { - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; - _activitySource = innerClient.GetService(); - _functionInvocationServices = functionInvocationServices; - } - - /// - /// Gets or sets the for the current function invocation. - /// - /// - /// This value flows across async calls. - /// - public static FunctionInvocationContext? 
CurrentContext - { - get => _currentContext.Value; - protected set => _currentContext.Value = value; - } - - /// - /// Gets or sets a value indicating whether detailed exception information should be included - /// in the chat history when calling the underlying . - /// - /// - /// if the full exception message is added to the chat history - /// when calling the underlying . - /// if a generic error message is included in the chat history. - /// The default value is . - /// - /// - /// - /// Setting the value to prevents the underlying language model from disclosing - /// raw exception details to the end user, since it doesn't receive that information. Even in this - /// case, the raw object is available to application code by inspecting - /// the property. - /// - /// - /// Setting the value to can help the underlying bypass problems on - /// its own, for example by retrying the function call with different arguments. However it might - /// result in disclosing the raw exception information to external users, which can be a security - /// concern depending on the application scenario. - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether detailed errors are provided during an in-flight request. - /// - /// - public bool IncludeDetailedErrors { get; set; } - - /// - /// Gets or sets a value indicating whether to allow concurrent invocation of functions. - /// - /// - /// if multiple function calls can execute in parallel. - /// if function calls are processed serially. - /// The default value is . - /// - /// - /// An individual response from the inner client might contain multiple function call requests. - /// By default, such function calls are processed serially. Set to - /// to enable concurrent invocation such that multiple function calls can execute in parallel. 
- /// - public bool AllowConcurrentInvocation { get; set; } - - /// - /// Gets or sets the maximum number of iterations per request. - /// - /// - /// The maximum number of iterations per request. - /// The default value is 10. - /// - /// - /// - /// Each request to this might end up making - /// multiple requests to the inner client. Each time the inner client responds with - /// a function call request, this client might perform that invocation and send the results - /// back to the inner client in a new request. This property limits the number of times - /// such a roundtrip is performed. The value must be at least one, as it includes the initial request. - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to how many iterations are allowed for an in-flight request. - /// - /// - public int MaximumIterationsPerRequest - { - get => _maximumIterationsPerRequest; - set - { - if (value < 1) - { - Throw.ArgumentOutOfRangeException(nameof(value)); - } - - _maximumIterationsPerRequest = value; - } - } - - /// - /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error. - /// - /// - /// The maximum number of consecutive iterations that are allowed to fail with an error. - /// The default value is 3. - /// - /// - /// - /// When function invocations fail with an exception, the - /// continues to make requests to the inner client, optionally supplying exception information (as - /// controlled by ). This allows the to - /// recover from errors by trying other function parameters that may succeed. - /// - /// - /// However, in case function invocations continue to produce exceptions, this property can be used to - /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be - /// rethrown to the caller. 
- /// - /// - /// If the value is set to zero, all function calling exceptions immediately terminate the function - /// invocation loop and the exception will be rethrown to the caller. - /// - /// - /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to how many iterations are allowed for an in-flight request. - /// - /// - public int MaximumConsecutiveErrorsPerRequest - { - get => _maximumConsecutiveErrorsPerRequest; - set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0); - } - - /// - public override async Task GetResponseAsync( - IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(messages); - - // A single request into this GetResponseAsync may result in multiple requests to the inner client. - // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - - // Copy the original messages in order to avoid enumerating the original messages multiple times. - // The IEnumerable can represent an arbitrary amount of work. - List originalMessages = [.. messages]; - messages = originalMessages; - - List? augmentedHistory = null; // the actual history of messages sent on turns other than the first - ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned - List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response - UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response - List? 
functionCallContents = null; // function call contents that need responding to in the current turn - bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set - int consecutiveErrorCount = 0; - - for (int iteration = 0; ; iteration++) - { - functionCallContents?.Clear(); - - // Make the call to the inner client. - response = await base.GetResponseAsync(messages, options, cancellationToken); - if (response is null) - { - Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}."); - } - - // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents. - bool requiresFunctionInvocation = - options?.Tools is { Count: > 0 } && - iteration < MaximumIterationsPerRequest && - CopyFunctionCalls(response.Messages, ref functionCallContents); - - // In a common case where we make a request and there's no function calling work required, - // fast path out by just returning the original response. - if (iteration == 0 && !requiresFunctionInvocation) - { - return response; - } - - // Track aggregatable details from the response, including all of the response messages and usage details. - (responseMessages ??= []).AddRange(response.Messages); - if (response.Usage is not null) - { - if (totalUsage is not null) - { - totalUsage.Add(response.Usage); - } - else - { - totalUsage = response.Usage; - } - } - - // If there are no tools to call, or for any other reason we should stop, we're done. - // Break out of the loop and allow the handling at the end to configure the response - // with aggregated data from previous requests. - if (!requiresFunctionInvocation) - { - break; - } - - // Prepare the history for the next iteration. 
- FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - - // Add the responses from the function calls into the augmented history and also into the tracked - // list of response messages. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken); - responseMessages.AddRange(modeAndMessages.MessagesAdded); - consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - - if (modeAndMessages.ShouldTerminate) - { - break; - } - - UpdateOptionsForNextIteration(ref options!, response.ChatThreadId); - } - - Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages."); - response.Messages = responseMessages!; - response.Usage = totalUsage; - - return response; - } - - /// - public override async IAsyncEnumerable GetStreamingResponseAsync( - IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - _ = Throw.IfNull(messages); - - // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client. - // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); - - // Copy the original messages in order to avoid enumerating the original messages multiple times. - // The IEnumerable can represent an arbitrary amount of work. - List originalMessages = [.. messages]; - messages = originalMessages; - - List? augmentedHistory = null; // the actual history of messages sent on turns other than the first - List? functionCallContents = null; // function call contents that need responding to in the current turn - List? 
responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history - bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set - List updates = []; // updates from the current response - int consecutiveErrorCount = 0; - - for (int iteration = 0; ; iteration++) - { - updates.Clear(); - functionCallContents?.Clear(); - - await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken)) - { - if (update is null) - { - Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}."); - } - - updates.Add(update); - - _ = CopyFunctionCalls(update.Contents, ref functionCallContents); - - yield return update; - Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 - } - - // If there are no tools to call, or for any other reason we should stop, return the response. - if (functionCallContents is not { Count: > 0 } || - options?.Tools is not { Count: > 0 } || - iteration >= _maximumIterationsPerRequest) - { - break; - } - - // Reconsistitue a response from the response updates. - var response = updates.ToChatResponse(); - (responseMessages ??= []).AddRange(response.Messages); - - // Prepare the history for the next iteration. - FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId); - - // Process all of the functions, adding their results into the history. - var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken); - responseMessages.AddRange(modeAndMessages.MessagesAdded); - consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; - - // This is a synthetic ID since we're generating the tool messages instead of getting them from - // the underlying provider. 
When emitting the streamed chunks, it's perfectly valid for us to - // use the same message ID for all of them within a given iteration, as this is a single logical - // message with multiple content items. We could also use different message IDs per tool content, - // but there's no benefit to doing so. - string toolResponseId = Guid.NewGuid().ToString("N"); - - // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages - // includes all activitys, including generated function results. - foreach (var message in modeAndMessages.MessagesAdded) - { - var toolResultUpdate = new ChatResponseUpdate - { - AdditionalProperties = message.AdditionalProperties, - AuthorName = message.AuthorName, - ChatThreadId = response.ChatThreadId, - CreatedAt = DateTimeOffset.UtcNow, - Contents = message.Contents, - RawRepresentation = message.RawRepresentation, - ResponseId = toolResponseId, - MessageId = toolResponseId, // See above for why this can be the same as ResponseId - Role = message.Role, - }; - - yield return toolResultUpdate; - Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802 - } - - if (modeAndMessages.ShouldTerminate) - { - yield break; - } - - UpdateOptionsForNextIteration(ref options, response.ChatThreadId); - } - } - - /// Prepares the various chat message lists after a response from the inner client and before invoking functions. - /// The original messages provided by the caller. - /// The messages reference passed to the inner client. - /// The augmented history containing all the messages to be sent. - /// The most recent response being handled. - /// A list of all response messages received up until this point. - /// Whether the previous iteration's response had a thread id. - private static void FixupHistories( - IEnumerable originalMessages, - ref IEnumerable messages, - [NotNull] ref List? 
augmentedHistory, - ChatResponse response, - List allTurnsResponseMessages, - ref bool lastIterationHadThreadId) - { - // We're now going to need to augment the history with function result contents. - // That means we need a separate list to store the augmented history. - if (response.ChatThreadId is not null) - { - // The response indicates the inner client is tracking the history, so we don't want to send - // anything we've already sent or received. - if (augmentedHistory is not null) - { - augmentedHistory.Clear(); - } - else - { - augmentedHistory = []; - } - - lastIterationHadThreadId = true; - } - else if (lastIterationHadThreadId) - { - // In the very rare case where the inner client returned a response with a thread ID but then - // returned a subsequent response without one, we want to reconstitue the full history. To do that, - // we can populate the history with the original chat messages and then all of the response - // messages up until this point, which includes the most recent ones. - augmentedHistory ??= []; - augmentedHistory.Clear(); - augmentedHistory.AddRange(originalMessages); - augmentedHistory.AddRange(allTurnsResponseMessages); - - lastIterationHadThreadId = false; - } - else - { - // If augmentedHistory is already non-null, then we've already populated it with everything up - // until this point (except for the most recent response). If it's null, we need to seed it with - // the chat history provided by the caller. - augmentedHistory ??= originalMessages.ToList(); - - // Now add the most recent response messages. - augmentedHistory.AddMessages(response); - - lastIterationHadThreadId = false; - } - - // Use the augmented history as the new set of messages to send. - messages = augmentedHistory; - } - - /// Copies any from to . - private static bool CopyFunctionCalls( - IList messages, [NotNullWhen(true)] ref List? 
functionCalls) - { - bool any = false; - int count = messages.Count; - for (int i = 0; i < count; i++) - { - any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls); - } - - return any; - } - - /// Copies any from to . - private static bool CopyFunctionCalls( - IList content, [NotNullWhen(true)] ref List? functionCalls) - { - bool any = false; - int count = content.Count; - for (int i = 0; i < count; i++) - { - if (content[i] is FunctionCallContent functionCall) - { - (functionCalls ??= []).Add(functionCall); - any = true; - } - } - - return any; - } - - private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId) - { - if (options.ToolMode is RequiredChatToolMode) - { - // We have to reset the tool mode to be non-required after the first iteration, - // as otherwise we'll be in an infinite loop. - options = options.Clone(); - options.ToolMode = null; - options.ChatThreadId = chatThreadId; - } - else if (options.ChatThreadId != chatThreadId) - { - // As with the other modes, ensure we've propagated the chat thread ID to the options. - // We only need to clone the options if we're actually mutating it. - options = options.Clone(); - options.ChatThreadId = chatThreadId; - } - } - - /// - /// Processes the function calls in the list. - /// - /// The current chat contents, inclusive of the function call contents being processed. - /// The options used for the response being processed. - /// The function call contents representing the functions to be invoked. - /// The iteration number of how many roundtrips have been made to the inner client. - /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors. - /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. 
- private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync( - List messages, ChatOptions options, List functionCallContents, int iteration, int consecutiveErrorCount, CancellationToken cancellationToken) - { - // We must add a response for every tool call, regardless of whether we successfully executed it or not. - // If we successfully execute it, we'll add the result. If we don't, we'll add an error. - - Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call."); - - var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; - - // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. - if (functionCallContents.Count == 1) - { - FunctionInvocationResult result = await ProcessFunctionCallAsync( - messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken); - - IList added = CreateResponseMessages([result]); - ThrowIfNoFunctionResultsAdded(added); - UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); - - messages.AddRange(added); - return (result.ShouldTerminate, consecutiveErrorCount, added); - } - else - { - FunctionInvocationResult[] results; - - if (AllowConcurrentInvocation) - { - // Rather than await'ing each function before invoking the next, invoke all of them - // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, - // but if a function invocation completes asynchronously, its processing can overlap - // with the processing of other the other invocation invocations. - results = await Task.WhenAll( - from i in Enumerable.Range(0, functionCallContents.Count) - select ProcessFunctionCallAsync( - messages, options, functionCallContents, - iteration, i, captureExceptions: true, cancellationToken)); - } - else - { - // Invoke each function serially. 
- results = new FunctionInvocationResult[functionCallContents.Count]; - for (int i = 0; i < results.Length; i++) - { - results[i] = await ProcessFunctionCallAsync( - messages, options, functionCallContents, - iteration, i, captureCurrentIterationExceptions, cancellationToken); - } - } - - var shouldTerminate = false; - - IList added = CreateResponseMessages(results); - ThrowIfNoFunctionResultsAdded(added); - UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount); - - messages.AddRange(added); - foreach (FunctionInvocationResult fir in results) - { - shouldTerminate = shouldTerminate || fir.ShouldTerminate; - } - - return (shouldTerminate, consecutiveErrorCount, added); - } - } - - private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount) - { - var allExceptions = added.SelectMany(m => m.Contents.OfType()) - .Select(frc => frc.Exception!) - .Where(e => e is not null); - - if (allExceptions.Any()) - { - consecutiveErrorCount++; - if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest) - { - var allExceptionsArray = allExceptions.ToArray(); - if (allExceptionsArray.Length == 1) - { - ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw(); - } - else - { - throw new AggregateException(allExceptionsArray); - } - } - } - else - { - consecutiveErrorCount = 0; - } - } - - /// - /// Throws an exception if doesn't create any messages. - /// - private void ThrowIfNoFunctionResultsAdded(IList? messages) - { - if (messages is null || messages.Count == 0) - { - Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages."); - } - } - - /// Processes the function call described in []. - /// The current chat contents, inclusive of the function call contents being processed. - /// The options used for the response being processed. - /// The function call contents representing all the functions being invoked. 
- /// The iteration number of how many roundtrips have been made to the inner client. - /// The 0-based index of the function being called out of . - /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows. - /// The to monitor for cancellation requests. - /// A value indicating how the caller should proceed. - private async Task ProcessFunctionCallAsync( - List messages, ChatOptions options, List callContents, - int iteration, int functionCallIndex, bool captureExceptions, CancellationToken cancellationToken) - { - var callContent = callContents[functionCallIndex]; - - // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. - AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name); - if (function is null) - { - return new(shouldTerminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null); - } - - FunctionInvocationContext context = new() - { - Function = function, - Arguments = new(callContent.Arguments) { Services = _functionInvocationServices }, - - Messages = messages, - Options = options, - - CallContent = callContent, - Iteration = iteration, - FunctionCallIndex = functionCallIndex, - FunctionCount = callContents.Count, - }; - - object? result; - try - { - result = await InvokeFunctionAsync(context, cancellationToken); - } - catch (Exception e) when (!cancellationToken.IsCancellationRequested) - { - if (!captureExceptions) - { - throw; - } - - return new( - shouldTerminate: false, - FunctionInvocationStatus.Exception, - callContent, - result: null, - exception: e); - } - - return new( - shouldTerminate: context.Terminate, - FunctionInvocationStatus.RanToCompletion, - callContent, - result, - exception: null); - } - - /// Creates one or more response messages for function invocation results. - /// Information about the function call invocations and results. 
- /// A list of all chat messages created from . - protected virtual IList CreateResponseMessages( - ReadOnlySpan results) - { - var contents = new List(results.Length); - for (int i = 0; i < results.Length; i++) - { - contents.Add(CreateFunctionResultContent(results[i])); - } - - return [new(ChatRole.Tool, contents)]; - - FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) - { - _ = Throw.IfNull(result); - - object? functionResult; - if (result.Status == FunctionInvocationStatus.RanToCompletion) - { - functionResult = result.Result ?? "Success: Function completed."; - } - else - { - string message = result.Status switch - { - FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.", - FunctionInvocationStatus.Exception => "Error: Function failed.", - _ => "Error: Unknown error.", - }; - - if (IncludeDetailedErrors && result.Exception is not null) - { - message = $"{message} Exception: {result.Exception.Message}"; - } - - functionResult = message; - } - - return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception }; - } - } - - /// Invokes the function asynchronously. - /// - /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. - /// - /// The to monitor for cancellation requests. The default is . - /// The result of the function invocation, or if the function invocation returned . - /// is . - protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) - { - _ = Throw.IfNull(context); - - using Activity? 
activity = _activitySource?.StartActivity(context.Function.Name); - - long startingTimestamp = 0; - if (_logger.IsEnabled(LogLevel.Debug)) - { - startingTimestamp = Stopwatch.GetTimestamp(); - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions)); - } - else - { - LogInvoking(context.Function.Name); - } - } - - object? result = null; - try - { - CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit - result = await context.Function.InvokeAsync(context.Arguments, cancellationToken); - } - catch (Exception e) - { - if (activity is not null) - { - _ = activity.SetTag("error.type", e.GetType().FullName) - .SetStatus(ActivityStatusCode.Error, e.Message); - } - - if (e is OperationCanceledException) - { - LogInvocationCanceled(context.Function.Name); - } - else - { - LogInvocationFailed(context.Function.Name, e); - } - - throw; - } - finally - { - if (_logger.IsEnabled(LogLevel.Debug)) - { - TimeSpan elapsed = GetElapsedTime(startingTimestamp); - - if (result is not null && _logger.IsEnabled(LogLevel.Trace)) - { - LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.JsonSerializerOptions)); - } - else - { - LogInvocationCompleted(context.Function.Name, elapsed); - } - } - } - - return result; - } - - private static TimeSpan GetElapsedTime(long startingTimestamp) => -#if NET - Stopwatch.GetElapsedTime(startingTimestamp); -#else - new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency))); -#endif - - [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)] - private partial void LogInvoking(string methodName); - - [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)] - private partial void 
LogInvokingSensitive(string methodName, string arguments); - - [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)] - private partial void LogInvocationCompleted(string methodName, TimeSpan duration); - - [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)] - private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result); - - [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")] - private partial void LogInvocationCanceled(string methodName); - - [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")] - private partial void LogInvocationFailed(string methodName, Exception error); - - /// Provides information about the invocation of a function call. - public sealed class FunctionInvocationResult - { - internal FunctionInvocationResult(bool shouldTerminate, FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception) - { - ShouldTerminate = shouldTerminate; - Status = status; - CallContent = callContent; - Result = result; - Exception = exception; - } - - /// Gets status about how the function invocation completed. - public FunctionInvocationStatus Status { get; } - - /// Gets the function call content information associated with this invocation. - public FunctionCallContent CallContent { get; } - - /// Gets the result of the function call. - public object? Result { get; } - - /// Gets any exception the function call threw. - public Exception? Exception { get; } - - /// Gets a value indicating whether the caller should terminate the processing loop. - internal bool ShouldTerminate { get; } - } - - /// Provides error codes for when errors occur as part of the function calling loop. - public enum FunctionInvocationStatus - { - /// The operation completed successfully. 
- RanToCompletion, - - /// The requested function could not be found. - NotFound, - - /// The function call failed with an exception. - Exception, - } -} From 6d7a374b9808d5e0f072a9279bb3ec1ec5db9833 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:35:24 +0100 Subject: [PATCH 21/26] Removing KernelFunctionINvocationContext in favor of Microsoft.Extensions.AI.FunctionInvocationContext --- .../AI/ChatClient/ChatClientExtensions.cs | 2 +- .../KernelFunctionInvocationContext.cs | 127 ------------------ .../AutoFunctionInvocationContext.cs | 4 +- 3 files changed, 3 insertions(+), 130 deletions(-) delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs index 223530beb1ba..e035a436a83a 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs @@ -92,7 +92,7 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl /// Creates a new that supports for function invocation with a . /// /// Target chat client service. - /// Optional logger to use for logging. + /// Optional logger factory to use for logging. /// Function invoking chat client. [Experimental("SKEXP0001")] public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILoggerFactory? 
loggerFactory = null) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs deleted file mode 100644 index efc3d39ab48a..000000000000 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using Microsoft.Extensions.AI; - -#pragma warning disable IDE0009 // Use explicit 'this.' qualifier -#pragma warning disable CA2213 // Disposable fields should be disposed -#pragma warning disable IDE0044 // Add readonly modifier - -namespace Microsoft.SemanticKernel.ChatCompletion; - -// Slight modified source from -// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs - -/// Provides context for an in-flight function invocation. -[ExcludeFromCodeCoverage] -internal class KernelFunctionInvocationContext -{ - /// - /// A nop function used to allow to be non-nullable. Default instances of - /// start with this as the target function. - /// - private static readonly AIFunction _nopFunction = AIFunctionFactory.Create(() => { }, nameof(FunctionInvocationContext)); - - /// The chat contents associated with the operation that initiated this function call request. - private IList _messages = Array.Empty(); - - /// The AI function to be invoked. - private AIFunction _function = _nopFunction; - - /// The function call content information associated with this invocation. - private Microsoft.Extensions.AI.FunctionCallContent? _callContent; - - /// The arguments used with the function. - private AIFunctionArguments? _arguments; - - /// Initializes a new instance of the class. 
- public KernelFunctionInvocationContext() - { - } - - /// Gets or sets the AI function to be invoked. - public AIFunction Function - { - get => _function; - set => _function = Throw.IfNull(value); - } - - /// Gets or sets the arguments associated with this invocation. - public AIFunctionArguments Arguments - { - get => _arguments ??= []; - set => _arguments = Throw.IfNull(value); - } - - /// Gets or sets the function call content information associated with this invocation. - public Microsoft.Extensions.AI.FunctionCallContent CallContent - { - get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary.Instance); - set => _callContent = Throw.IfNull(value); - } - - /// Gets or sets the chat contents associated with the operation that initiated this function call request. - public IList Messages - { - get => _messages; - set => _messages = Throw.IfNull(value); - } - - /// Gets or sets the chat options associated with the operation that initiated this function call request. - public ChatOptions? Options { get; set; } - - /// Gets or sets the number of this iteration with the underlying client. - /// - /// The initial request to the client that passes along the chat contents provided to the - /// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on. - /// - public int Iteration { get; set; } - - /// Gets or sets the index of the function call within the iteration. - /// - /// The response from the underlying client may include multiple function call requests. - /// This index indicates the position of the function call within the iteration. - /// - public int FunctionCallIndex { get; set; } - - /// Gets or sets the total number of function call requests within the iteration. - /// - /// The response from the underlying client might include multiple function call requests. - /// This count indicates how many there were. 
- /// - public int FunctionCount { get; set; } - - /// Gets or sets a value indicating whether to terminate the request. - /// - /// In response to a function call request, the function might be invoked, its result added to the chat contents, - /// and a new request issued to the wrapped client. If this property is set to , that subsequent request - /// will not be issued and instead the loop immediately terminated rather than continuing until there are no - /// more function call requests in responses. - /// - public bool Terminate { get; set; } - - private static class Throw - { - /// - /// Throws an if the specified argument is . - /// - /// Argument type to be checked for . - /// Object to be checked for . - /// The name of the parameter being checked. - /// The original value of . - [MethodImpl(MethodImplOptions.AggressiveInlining)] - [return: NotNull] - public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") - { - if (argument is null) - { - throw new ArgumentNullException(paramName); - } - - return argument; - } - } -} diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index f70205365330..a0ddb480c275 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -17,12 +17,12 @@ public class AutoFunctionInvocationContext { private ChatHistory? _chatHistory; private KernelFunction? _kernelFunction; - private readonly KernelFunctionInvocationContext _invocationContext = new(); + private readonly Microsoft.Extensions.AI.FunctionInvocationContext _invocationContext = new(); /// /// Initializes a new instance of the class from an existing . 
/// - internal AutoFunctionInvocationContext(KernelFunctionInvocationContext invocationContext) + internal AutoFunctionInvocationContext(Microsoft.Extensions.AI.FunctionInvocationContext invocationContext) { Verify.NotNull(invocationContext); Verify.NotNull(invocationContext.Options); From 7a380273bd21e7dcd5c7e2db5908bdbd2c2f3bf8 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:58:32 +0100 Subject: [PATCH 22/26] Fix reference --- .../AutoFunctionInvocation/AutoFunctionInvocationContext.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs index a0ddb480c275..bc8dd0c3490c 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs @@ -20,7 +20,7 @@ public class AutoFunctionInvocationContext private readonly Microsoft.Extensions.AI.FunctionInvocationContext _invocationContext = new(); /// - /// Initializes a new instance of the class from an existing . + /// Initializes a new instance of the class from an existing . 
/// internal AutoFunctionInvocationContext(Microsoft.Extensions.AI.FunctionInvocationContext invocationContext) { From c989c142e7c4ad1daeb0dd4b88bdf8254c2ab396 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 14 Apr 2025 13:00:41 +0100 Subject: [PATCH 23/26] Typo fix --- .../AI/ChatClient/KernelFunctionInvokingChatClient.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index e6496cc63f4f..5c77bb415867 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -367,7 +367,7 @@ public override async IAsyncEnumerable GetStreamingResponseA break; } - // Reconsistitue a response from the response updates. + // Reconstitute a response from the response updates. var response = updates.ToChatResponse(); (responseMessages ??= []).AddRange(response.Messages); From 4316651ad72f81e418a4c0fbab1a750997218aa6 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Mon, 14 Apr 2025 13:08:06 +0100 Subject: [PATCH 24/26] Fix typos + virtual to private --- .../KernelFunctionInvokingChatClient.cs | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 5c77bb415867..544fefe72493 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -117,13 +117,13 @@ public static AutoFunctionInvocationContext? 
CurrentContext /// /// /// Setting the value to can help the underlying bypass problems on - /// its own, for example by retrying the function call with different arguments. However it might + /// its own, for example by retrying the function call with different arguments. However, it might /// result in disclosing the raw exception information to external users, which can be a security /// concern depending on the application scenario. /// /// /// Changing the value of this property while the client is in use might result in inconsistencies - /// as to whether detailed errors are provided during an in-flight request. + /// regarding whether detailed errors are provided during an in-flight request. /// /// public bool IncludeDetailedErrors { get; set; } @@ -265,7 +265,7 @@ public override async Task GetResponseAsync( return response; } - // Track aggregatable details from the response, including all of the response messages and usage details. + // Track aggregate details from the response, including all the response messages and usage details. (responseMessages ??= []).AddRange(response.Messages); if (response.Usage is not null) { @@ -377,7 +377,7 @@ public override async IAsyncEnumerable GetStreamingResponseA // Prepare the options for the next auto function invocation iteration. UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true); - // Process all of the functions, adding their results into the history. + // Process all the functions, adding their results into the history. 
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken).ConfigureAwait(false); responseMessages.AddRange(modeAndMessages.MessagesAdded); consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount; @@ -393,7 +393,7 @@ public override async IAsyncEnumerable GetStreamingResponseA string toolResponseId = Guid.NewGuid().ToString("N"); // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages - // includes all activitys, including generated function results. + // include all activity, including generated function results. foreach (var message in modeAndMessages.MessagesAdded) { var toolResultUpdate = new ChatResponseUpdate @@ -457,8 +457,8 @@ private static void FixupHistories( else if (lastIterationHadThreadId) { // In the very rare case where the inner client returned a response with a thread ID but then - // returned a subsequent response without one, we want to reconstitue the full history. To do that, - // we can populate the history with the original chat messages and then all of the response + // returned a subsequent response without one, we want to reconstitute the full history. To do that, + // we can populate the history with the original chat messages and then all the response // messages up until this point, which includes the most recent ones. augmentedHistory ??= []; augmentedHistory.Clear(); @@ -584,7 +584,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin // We must add a response for every tool call, regardless of whether we successfully executed it or not. // If we successfully execute it, we'll add the result. If we don't, we'll add an error. 
- Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call."); + Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call."); var shouldTerminate = false; var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest; @@ -609,7 +609,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin var terminationRequested = false; if (AllowConcurrentInvocation) { - // Rather than await'ing each function before invoking the next, invoke all of them + // Rather than awaiting each function before invoking the next, invoke all of them // and then await all of them. We avoid forcibly introducing parallelism via Task.Run, // but if a function invocation completes asynchronously, its processing can overlap // with the processing of other the other invocation invocations. @@ -776,8 +776,7 @@ private async Task ProcessFunctionCallAsync( /// Creates one or more response messages for function invocation results. /// Information about the function call invocations and results. /// A list of all chat messages created from . - protected virtual IList CreateResponseMessages( - IReadOnlyList results) + private IList CreateResponseMessages(List results) { var contents = new List(results.Count); for (int i = 0; i < results.Count; i++) @@ -866,7 +865,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync( /// The to monitor for cancellation requests. The default is . /// The result of the function invocation, or if the function invocation returned . /// is . 
- protected virtual async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken) + private async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken) { Verify.NotNull(context); From cb636e941155aec247f9451056d4912a5880d78f Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Tue, 15 Apr 2025 13:32:03 +0100 Subject: [PATCH 25/26] Address PR comments --- .../OpenAIKernelBuilderExtensions.ChatClient.cs | 11 +++++------ .../Extensions/OpenAIKernelBuilderExtensions.cs | 8 ++++---- .../AI/ChatClient/ChatOptionsExtensions.cs | 2 +- .../AI/ChatClient/KernelFunctionInvokingChatClient.cs | 2 +- .../Functions/KernelFunction.cs | 1 - 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs index ac0005f5d5c5..9d1832b340ff 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs @@ -3,20 +3,19 @@ using System; using System.Diagnostics.CodeAnalysis; using System.Net.Http; +using Microsoft.Extensions.AI; using OpenAI; namespace Microsoft.SemanticKernel; -/// -/// Sponsor extensions class for . -/// +/// Extension methods for . [Experimental("SKEXP0010")] public static class OpenAIChatClientKernelBuilderExtensions { #region Chat Completion /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an OpenAI to the . /// /// The instance to augment. /// OpenAI model name, see https://platform.openai.com/docs/models @@ -46,7 +45,7 @@ public static IKernelBuilder AddOpenAIChatClient( } /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an OpenAI to the . 
/// /// The instance to augment. /// OpenAI model id @@ -70,7 +69,7 @@ public static IKernelBuilder AddOpenAIChatClient( } /// - /// Adds the Custom Endpoint OpenAI chat completion service to the list. + /// Adds a custom endpoint OpenAI to the . /// /// The instance to augment. /// OpenAI model name, see https://platform.openai.com/docs/models diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs index a07a81fdb5b3..01307b9adc2a 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs @@ -20,7 +20,7 @@ namespace Microsoft.SemanticKernel; /// -/// Sponsor extensions class for . +/// Extension methods for . /// public static class OpenAIKernelBuilderExtensions { @@ -269,7 +269,7 @@ public static IKernelBuilder AddOpenAIFiles( #region Chat Completion /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an to the . /// /// The instance to augment. /// OpenAI model name, see https://platform.openai.com/docs/models @@ -304,7 +304,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _) } /// - /// Adds the OpenAI chat completion service to the list. + /// Adds an to the . /// /// The instance to augment. /// OpenAI model id @@ -330,7 +330,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _) } /// - /// Adds the Custom Endpoint OpenAI chat completion service to the list. + /// Adds a custom endpoint to the . /// /// The instance to augment. 
/// OpenAI model name, see https://platform.openai.com/docs/models diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs index 5db8240b1707..68540a1c32d8 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs @@ -137,7 +137,7 @@ internal static ChatOptions AddKernel(this ChatOptions options, Kernel? kernel) if (kernel is not null) { options.AdditionalProperties ??= []; - options.AdditionalProperties?.TryAdd(KernelKey, kernel); + options.AdditionalProperties.TryAdd(KernelKey, kernel); } return options; diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs index 544fefe72493..ea2dce48fc62 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs @@ -226,7 +226,7 @@ public override async Task GetResponseAsync( // A single request into this GetResponseAsync may result in multiple requests to the inner client. // Create an activity to group them together for better observability. - using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient)); + using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient)); // Copy the original messages in order to avoid enumerating the original messages multiple times. // The IEnumerable can represent an arbitrary amount of work. 
diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs index cb480c71c605..0e09c1d525b7 100644 --- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs +++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs @@ -542,7 +542,6 @@ public KernelAIFunction(KernelFunction kernelFunction, Kernel? kernel) this.JsonSchema = BuildFunctionSchema(kernelFunction); } - internal string KernelFunctionName => this._kernelFunction.Name; public override string Name { get; } public override JsonElement JsonSchema { get; } public override string Description => this._kernelFunction.Description; From 7043c726c2de9d455a28fed55427f4de639f7384 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:29:14 +0100 Subject: [PATCH 26/26] Address PR feedback --- .../AI/ChatClient/ChatClientAIService.cs | 2 +- .../AI/ChatCompletion/ChatHistory.cs | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs index b840b33e690b..58ea317804f9 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs @@ -33,7 +33,7 @@ internal ChatClientAIService(IChatClient chatClient) var metadata = this._chatClient.GetService(); Verify.NotNull(metadata); - this._internalAttributes["ModelId"] = metadata.DefaultModelId; + this._internalAttributes[AIServiceExtensions.ModelIdKey] = metadata.DefaultModelId; this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName; this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri; } diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs 
b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs index 0054dc98a400..147cdd5ba332 100644 --- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs +++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs @@ -35,7 +35,9 @@ public ChatHistory() this._messages = []; } - // Due to changes using the AutoFunctionInvocation as a dependency of KernelInvocation, that needs to reflect + // This allows changes to the chat history to be observed by reference, so that they are + // reflected in the internal IEnumerable when the history is used from IChatClients + // with AutoFunctionInvocationFilters. internal void SetOverrides( Action overrideAdd, Func overrideRemove,