From fa1776f2297c63f9abb1b44510eb00f55229498a Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Tue, 1 Apr 2025 21:34:20 +0100
Subject: [PATCH 01/26] AddChatClient OpenAI WIP
---
...tClient_AutoFunctionInvocationFiltering.cs | 167 ++++++++++++++++++
.../Connectors.OpenAI.csproj | 1 +
...penAIKernelBuilderExtensions.ChatClient.cs | 106 +++++++++++
...IServiceCollectionExtensions.ChatClient.cs | 157 ++++++++++++++++
.../AI/ChatClient/ChatClientExtensions.cs | 6 +-
.../ClientResultExceptionExtensionsTests.cs | 1 -
6 files changed, 435 insertions(+), 3 deletions(-)
create mode 100644 dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
diff --git a/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs
new file mode 100644
index 000000000000..1e053618a385
--- /dev/null
+++ b/dotnet/samples/Concepts/Filtering/ChatClient_AutoFunctionInvocationFiltering.cs
@@ -0,0 +1,167 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+
+namespace Filtering;
+
+public class ChatClient_AutoFunctionInvocationFiltering(ITestOutputHelper output) : BaseTest(output)
+{
+ /// <summary>
+ /// Shows how to use <see cref="IAutoFunctionInvocationFilter"/>.
+ /// </summary>
+ [Fact]
+ public async Task UsingAutoFunctionInvocationFilter()
+ {
+ var builder = Kernel.CreateBuilder();
+
+ builder.AddOpenAIChatClient("gpt-4", TestConfiguration.OpenAI.ApiKey);
+
+ // This filter outputs information about auto function invocation and returns overridden result.
+ builder.Services.AddSingleton(new AutoFunctionInvocationFilter(this.Output));
+
+ var kernel = builder.Build();
+
+ var function = KernelFunctionFactory.CreateFromMethod(() => "Result from function", "MyFunction");
+
+ kernel.ImportPluginFromFunctions("MyPlugin", [function]);
+
+ var executionSettings = new OpenAIPromptExecutionSettings
+ {
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Required([function], autoInvoke: true)
+ };
+
+ var result = await kernel.InvokePromptAsync("Invoke provided function and return result", new(executionSettings));
+
+ Console.WriteLine(result);
+
+ // Output:
+ // Request sequence index: 0
+ // Function sequence index: 0
+ // Total number of functions: 1
+ // Result from auto function invocation filter.
+ }
+
+ /// <summary>
+ /// Shows how to get list of function calls by using <see cref="IAutoFunctionInvocationFilter"/>.
+ /// </summary>
+ [Fact]
+ public async Task GetFunctionCallsWithFilterAsync()
+ {
+ var builder = Kernel.CreateBuilder();
+
+ builder.AddOpenAIChatCompletion("gpt-3.5-turbo-1106", TestConfiguration.OpenAI.ApiKey);
+
+ builder.Services.AddSingleton(new FunctionCallsFilter(this.Output));
+
+ var kernel = builder.Build();
+
+ kernel.ImportPluginFromFunctions("HelperFunctions",
+ [
+ kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcTime", "Retrieves the current time in UTC."),
+ kernel.CreateFunctionFromMethod((string cityName) =>
+ cityName switch
+ {
+ "Boston" => "61 and rainy",
+ "London" => "55 and cloudy",
+ "Miami" => "80 and sunny",
+ "Paris" => "60 and rainy",
+ "Tokyo" => "50 and sunny",
+ "Sydney" => "75 and sunny",
+ "Tel Aviv" => "80 and sunny",
+ _ => "31 and snowing",
+ }, "GetWeatherForCity", "Gets the current weather for the specified city"),
+ ]);
+
+ var executionSettings = new OpenAIPromptExecutionSettings
+ {
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
+ };
+
+ await foreach (var chunk in kernel.InvokePromptStreamingAsync("Check current UTC time and return current weather in Boston city.", new(executionSettings)))
+ {
+ Console.WriteLine(chunk.ToString());
+ }
+
+ // Output:
+ // Request #0. Function call: HelperFunctions.GetCurrentUtcTime.
+ // Request #0. Function call: HelperFunctions.GetWeatherForCity.
+ // The current UTC time is {time of execution}, and the current weather in Boston is 61°F and rainy.
+ }
+
+ /// <summary>Shows available syntax for auto function invocation filter.</summary>
+ private sealed class AutoFunctionInvocationFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter
+ {
+ public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
+ {
+ // Example: get function information
+ var functionName = context.Function.Name;
+
+ // Example: get chat history
+ var chatHistory = context.ChatHistory;
+
+ // Example: get information about all functions which will be invoked
+ var functionCalls = FunctionCallContent.GetFunctionCalls(context.ChatHistory.Last());
+
+ // In function calling functionality there are two loops.
+ // Outer loop is "request" loop - it performs multiple requests to LLM until the user's request is satisfied.
+ // Inner loop is "function" loop - it handles LLM response with multiple function calls.
+
+ // Workflow example:
+ // 1. Request to LLM #1 -> Response with 3 functions to call.
+ // 1.1. Function #1 called.
+ // 1.2. Function #2 called.
+ // 1.3. Function #3 called.
+ // 2. Request to LLM #2 -> Response with 2 functions to call.
+ // 2.1. Function #1 called.
+ // 2.2. Function #2 called.
+
+ // context.RequestSequenceIndex - it's a sequence number of outer/request loop operation.
+ // context.FunctionSequenceIndex - it's a sequence number of inner/function loop operation.
+ // context.FunctionCount - number of functions which will be called per request (based on example above: 3 for first request, 2 for second request).
+
+ // Example: get request sequence index
+ output.WriteLine($"Request sequence index: {context.RequestSequenceIndex}");
+
+ // Example: get function sequence index
+ output.WriteLine($"Function sequence index: {context.FunctionSequenceIndex}");
+
+ // Example: get total number of functions which will be called
+ output.WriteLine($"Total number of functions: {context.FunctionCount}");
+
+ // Calling next filter in pipeline or function itself.
+ // By skipping this call, next filters and function won't be invoked, and function call loop will proceed to the next function.
+ await next(context);
+
+ // Example: get function result
+ var result = context.Result;
+
+ // Example: override function result value
+ context.Result = new FunctionResult(context.Result, "Result from auto function invocation filter");
+
+ // Example: Terminate function invocation
+ context.Terminate = true;
+ }
+ }
+
+ /// <summary>Shows how to get list of all function calls per request.</summary>
+ private sealed class FunctionCallsFilter(ITestOutputHelper output) : IAutoFunctionInvocationFilter
+ {
+ public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
+ {
+ var chatHistory = context.ChatHistory;
+ var functionCalls = FunctionCallContent.GetFunctionCalls(chatHistory.Last()).ToArray();
+
+ if (functionCalls is { Length: > 0 })
+ {
+ foreach (var functionCall in functionCalls)
+ {
+ output.WriteLine($"Request #{context.RequestSequenceIndex}. Function call: {functionCall.PluginName}.{functionCall.FunctionName}.");
+ }
+ }
+
+ await next(context);
+ }
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj
index 64a0e72bde6d..c17e878a7a42 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Connectors.OpenAI.csproj
@@ -37,6 +37,7 @@
+
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
new file mode 100644
index 000000000000..e38aa5289104
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
@@ -0,0 +1,106 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Diagnostics.CodeAnalysis;
+using System.Net.Http;
+using OpenAI;
+
+namespace Microsoft.SemanticKernel;
+
+///
+/// Sponsor extensions class for .
+///
+public static class OpenAIChatClientKernelBuilderExtensions
+{
+ #region Chat Completion
+
+ ///
+ /// Adds the OpenAI chat completion service to the list.
+ ///
+ /// The instance to augment.
+ /// OpenAI model name, see https://platform.openai.com/docs/models
+ /// OpenAI API key, see https://platform.openai.com/account/api-keys
+ /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.
+ /// A local identifier for the given AI service
+ /// The HttpClient to use with this service.
+ /// The same instance as .
+ public static IKernelBuilder AddOpenAIChatClient(
+ this IKernelBuilder builder,
+ string modelId,
+ string apiKey,
+ string? orgId = null,
+ string? serviceId = null,
+ HttpClient? httpClient = null)
+ {
+ Verify.NotNull(builder);
+
+ builder.Services.AddOpenAIChatClient(
+ modelId,
+ apiKey,
+ orgId,
+ serviceId,
+ httpClient);
+
+ return builder;
+ }
+
+ ///
+ /// Adds the OpenAI chat completion service to the list.
+ ///
+ /// The instance to augment.
+ /// OpenAI model id
+ /// to use for the service. If null, one must be available in the service provider when this service is resolved.
+ /// A local identifier for the given AI service
+ /// The same instance as .
+ public static IKernelBuilder AddOpenAIChatClient(
+ this IKernelBuilder builder,
+ string modelId,
+ OpenAIClient? openAIClient = null,
+ string? serviceId = null)
+ {
+ Verify.NotNull(builder);
+
+ builder.Services.AddOpenAIChatClient(
+ modelId,
+ openAIClient,
+ serviceId);
+
+ return builder;
+ }
+
+ ///
+ /// Adds the Custom Endpoint OpenAI chat completion service to the list.
+ ///
+ /// The instance to augment.
+ /// OpenAI model name, see https://platform.openai.com/docs/models
+ /// Custom OpenAI Compatible Message API endpoint
+ /// OpenAI API key, see https://platform.openai.com/account/api-keys
+ /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.
+ /// A local identifier for the given AI service
+ /// The HttpClient to use with this service.
+ /// The same instance as .
+ [Experimental("SKEXP0010")]
+ public static IKernelBuilder AddOpenAIChatClient(
+ this IKernelBuilder builder,
+ string modelId,
+ Uri endpoint,
+ string? apiKey,
+ string? orgId = null,
+ string? serviceId = null,
+ HttpClient? httpClient = null)
+ {
+ Verify.NotNull(builder);
+
+ builder.Services.AddOpenAIChatClient(
+ modelId,
+ endpoint,
+ apiKey,
+ orgId,
+ serviceId,
+ httpClient);
+
+ return builder;
+ }
+
+ #endregion
+}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
new file mode 100644
index 000000000000..ff214c441c2e
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
@@ -0,0 +1,157 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.ClientModel;
+using System.ClientModel.Primitives;
+using System.Net.Http;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Http;
+using OpenAI;
+
+namespace Microsoft.SemanticKernel;
+
+#pragma warning disable IDE0039 // Use local function
+
+///
+/// Sponsor extensions class for .
+///
+public static class OpenAIChatClientServiceCollectionExtensions
+{
+ #region Chat Completion
+
+ ///
+ /// Adds the OpenAI chat completion service to the list.
+ ///
+ /// The instance to augment.
+ /// OpenAI model name, see https://platform.openai.com/docs/models
+ /// OpenAI API key, see https://platform.openai.com/account/api-keys
+ /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.
+ /// A local identifier for the given AI service
+ /// The HttpClient to use with this service.
+ /// The same instance as .
+ public static IServiceCollection AddOpenAIChatClient(
+ this IServiceCollection services,
+ string modelId,
+ string apiKey,
+ string? orgId = null,
+ string? serviceId = null,
+ HttpClient? httpClient = null)
+ {
+ Verify.NotNull(services);
+
+ IChatClient Factory(IServiceProvider serviceProvider, object? _)
+ {
+ ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+
+ return new Microsoft.Extensions.AI.OpenAIChatClient(
+ openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))),
+ modelId: modelId)
+ .AsKernelFunctionInvokingChatClient(logger);
+ }
+
+ services.AddKeyedSingleton<IChatClient>(serviceId, (Func<IServiceProvider, object?, IChatClient>)Factory);
+
+ return services;
+ }
+
+ ///
+ /// Adds the OpenAI chat completion service to the list.
+ ///
+ /// The instance to augment.
+ /// OpenAI model id
+ /// to use for the service. If null, one must be available in the service provider when this service is resolved.
+ /// A local identifier for the given AI service
+ /// The same instance as .
+ public static IServiceCollection AddOpenAIChatClient(this IServiceCollection services,
+ string modelId,
+ OpenAIClient? openAIClient = null,
+ string? serviceId = null)
+ {
+ Verify.NotNull(services);
+
+ IChatClient Factory(IServiceProvider serviceProvider, object? _)
+ {
+ ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+
+ return new Microsoft.Extensions.AI.OpenAIChatClient(
+ openAIClient ?? serviceProvider.GetRequiredService<OpenAIClient>(),
+ modelId)
+ .AsKernelFunctionInvokingChatClient(logger);
+ }
+
+ services.AddKeyedSingleton<IChatClient>(serviceId, (Func<IServiceProvider, object?, IChatClient>)Factory);
+
+ return services;
+ }
+
+ ///
+ /// Adds the Custom OpenAI chat completion service to the list.
+ ///
+ /// The instance to augment.
+ /// OpenAI model name, see https://platform.openai.com/docs/models
+ /// A Custom Message API compatible endpoint.
+ /// OpenAI API key, see https://platform.openai.com/account/api-keys
+ /// OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.
+ /// A local identifier for the given AI service
+ /// The HttpClient to use with this service.
+ /// The same instance as .
+ public static IServiceCollection AddOpenAIChatClient(
+ this IServiceCollection services,
+ string modelId,
+ Uri endpoint,
+ string? apiKey = null,
+ string? orgId = null,
+ string? serviceId = null,
+ HttpClient? httpClient = null)
+ {
+ Verify.NotNull(services);
+
+ IChatClient Factory(IServiceProvider serviceProvider, object? _)
+ {
+ ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+
+ return new Microsoft.Extensions.AI.OpenAIChatClient(
+ openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))),
+ modelId: modelId)
+ .AsKernelFunctionInvokingChatClient(logger);
+ }
+
+ services.AddKeyedSingleton<IChatClient>(serviceId, (Func<IServiceProvider, object?, IChatClient>)Factory);
+
+ return services;
+ }
+
+ private static OpenAIClientOptions GetClientOptions(
+ Uri? endpoint = null,
+ string? orgId = null,
+ HttpClient? httpClient = null)
+ {
+ OpenAIClientOptions options = new();
+
+ if (endpoint is not null)
+ {
+ options.Endpoint = endpoint;
+ }
+
+ if (orgId is not null)
+ {
+ options.OrganizationId = orgId;
+ }
+
+ if (httpClient is not null)
+ {
+ options.Transport = new HttpClientPipelineTransport(httpClient);
+ }
+
+ return options;
+ }
+
+ ///
+ /// White space constant.
+ ///
+ private const string SingleSpace = " ";
+ #endregion
+}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
index c36ca453bd04..c480772750de 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
@@ -5,6 +5,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
+using Microsoft.Extensions.Logging;
namespace Microsoft.SemanticKernel.ChatCompletion;
@@ -71,14 +72,15 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl
/// Creates a new that supports for function invocation with a .
///
/// Target chat client service.
+ /// Optional logger to use for logging.
/// Function invoking chat client.
[Experimental("SKEXP0001")]
- public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client)
+ public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILogger? logger = null)
{
Verify.NotNull(client);
return client is KernelFunctionInvokingChatClient kernelFunctionInvocationClient
? kernelFunctionInvocationClient
- : new KernelFunctionInvokingChatClient(client);
+ : new KernelFunctionInvokingChatClient(client, logger);
}
}
diff --git a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs
index c9c348d1ac44..c25fa0ccd249 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Extensions/ClientResultExceptionExtensionsTests.cs
@@ -64,7 +64,6 @@ public void ItProvideStatusForResponsesWithoutContent()
// Assert
Assert.NotNull(httpOperationException);
Assert.NotNull(httpOperationException.StatusCode);
- Assert.Null(httpOperationException.ResponseContent);
Assert.Equal(exception, httpOperationException.InnerException);
Assert.Equal(exception.Message, httpOperationException.Message);
Assert.Equal(pipelineResponse.Status, (int)httpOperationException.StatusCode!);
From 69ff2c2a57b6f76959b9212a3cff989958097daf Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Thu, 3 Apr 2025 13:23:46 +0100
Subject: [PATCH 02/26] Adding UT and Extension Methods
---
...FunctionInvocationFilterChatClientTests.cs | 776 ++++++++++++++++++
.../Core/AutoFunctionInvocationFilterTests.cs | 2 +-
...IKernelBuilderExtensionsChatClientTests.cs | 92 +++
...viceCollectionExtensionsChatClientTests.cs | 114 +++
...penAIKernelBuilderExtensions.ChatClient.cs | 2 +-
...IServiceCollectionExtensions.ChatClient.cs | 15 +-
.../KernelFunctionInvocationContext.cs | 2 +
.../KernelFunctionInvokingChatClient.cs | 16 +-
8 files changed, 1001 insertions(+), 18 deletions(-)
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
new file mode 100644
index 000000000000..9a8e91a25df5
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -0,0 +1,776 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Net.Http;
+using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+using Xunit;
+
+namespace SemanticKernel.Connectors.OpenAI.UnitTests.Core;
+
+public sealed class AutoFunctionInvocationFilterChatClientTests : IDisposable
+{
+ private readonly MultipleHttpMessageHandlerStub _messageHandlerStub;
+ private readonly HttpClient _httpClient;
+
+ public AutoFunctionInvocationFilterChatClientTests()
+ {
+ this._messageHandlerStub = new MultipleHttpMessageHandlerStub();
+
+ this._httpClient = new HttpClient(this._messageHandlerStub, false);
+ }
+
+ [Fact]
+ public async Task FiltersAreExecutedCorrectlyAsync()
+ {
+ // Arrange
+ int filterInvocations = 0;
+ int functionInvocations = 0;
+ int[] expectedRequestSequenceNumbers = [0, 0, 1, 1];
+ int[] expectedFunctionSequenceNumbers = [0, 1, 0, 1];
+ List<int> requestSequenceNumbers = [];
+ List<int> functionSequenceNumbers = [];
+ Kernel? contextKernel = null;
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ contextKernel = context.Kernel;
+
+ if (context.ChatHistory.Last() is OpenAIChatMessageContent content)
+ {
+ Assert.Equal(2, content.ToolCalls.Count);
+ }
+
+ requestSequenceNumbers.Add(context.RequestSequenceIndex);
+ functionSequenceNumbers.Add(context.FunctionSequenceIndex);
+
+ await next(context);
+
+ filterInvocations++;
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ // Act
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ }));
+
+ // Assert
+ Assert.Equal(4, filterInvocations);
+ Assert.Equal(4, functionInvocations);
+ Assert.Equal(expectedRequestSequenceNumbers, requestSequenceNumbers);
+ Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers);
+ Assert.Same(kernel, contextKernel);
+ Assert.Equal("Test chat response", result.ToString());
+ }
+
+ [Fact]
+ public async Task FunctionSequenceIndexIsCorrectForConcurrentCallsAsync()
+ {
+ // Arrange
+ List<int> functionSequenceNumbers = [];
+ List<int> expectedFunctionSequenceNumbers = [0, 1, 0, 1];
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ functionSequenceNumbers.Add(context.FunctionSequenceIndex);
+
+ await next(context);
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ // Act
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(options: new()
+ {
+ AllowParallelCalls = true,
+ AllowConcurrentInvocation = true
+ })
+ }));
+
+ // Assert
+ Assert.Equal(expectedFunctionSequenceNumbers, functionSequenceNumbers);
+ }
+
+ [Fact]
+ public async Task FiltersAreExecutedCorrectlyOnStreamingAsync()
+ {
+ // Arrange
+ int filterInvocations = 0;
+ int functionInvocations = 0;
+ List<int> requestSequenceNumbers = [];
+ List<int> functionSequenceNumbers = [];
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { functionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ if (context.ChatHistory.Last() is OpenAIChatMessageContent content)
+ {
+ Assert.Equal(2, content.ToolCalls.Count);
+ }
+
+ requestSequenceNumbers.Add(context.RequestSequenceIndex);
+ functionSequenceNumbers.Add(context.FunctionSequenceIndex);
+
+ await next(context);
+
+ filterInvocations++;
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+
+ // Act
+ await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
+ { }
+
+ // Assert
+ Assert.Equal(4, filterInvocations);
+ Assert.Equal(4, functionInvocations);
+ Assert.Equal([0, 0, 1, 1], requestSequenceNumbers);
+ Assert.Equal([0, 1, 0, 1], functionSequenceNumbers);
+ }
+
+ [Fact]
+ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
+ {
+ // Arrange
+ var executionOrder = new List<string>();
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var filter1 = new AutoFunctionInvocationFilter(async (context, next) =>
+ {
+ executionOrder.Add("Filter1-Invoking");
+ await next(context);
+ });
+
+ var filter2 = new AutoFunctionInvocationFilter(async (context, next) =>
+ {
+ executionOrder.Add("Filter2-Invoking");
+ await next(context);
+ });
+
+ var builder = Kernel.CreateBuilder();
+
+ builder.Plugins.Add(plugin);
+
+ builder.Services.AddSingleton<IChatCompletionService>((serviceProvider) =>
+ {
+ return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ // Act
+
+ // Case #1 - Add filter to services
+ builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(filter1);
+
+ var kernel = builder.Build();
+
+ // Case #2 - Add filter to kernel
+ kernel.AutoFunctionInvocationFilters.Add(filter2);
+
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ }));
+
+ // Assert
+ Assert.Equal("Filter1-Invoking", executionOrder[0]);
+ Assert.Equal("Filter2-Invoking", executionOrder[1]);
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming)
+ {
+ // Arrange
+ var executionOrder = new List<string>();
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var filter1 = new AutoFunctionInvocationFilter(async (context, next) =>
+ {
+ executionOrder.Add("Filter1-Invoking");
+ await next(context);
+ executionOrder.Add("Filter1-Invoked");
+ });
+
+ var filter2 = new AutoFunctionInvocationFilter(async (context, next) =>
+ {
+ executionOrder.Add("Filter2-Invoking");
+ await next(context);
+ executionOrder.Add("Filter2-Invoked");
+ });
+
+ var filter3 = new AutoFunctionInvocationFilter(async (context, next) =>
+ {
+ executionOrder.Add("Filter3-Invoking");
+ await next(context);
+ executionOrder.Add("Filter3-Invoked");
+ });
+
+ var builder = Kernel.CreateBuilder();
+
+ builder.Plugins.Add(plugin);
+
+ builder.Services.AddSingleton<IChatCompletionService>((serviceProvider) =>
+ {
+ return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+ });
+
+ builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(filter1);
+ builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(filter2);
+ builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(filter3);
+
+ var kernel = builder.Build();
+
+ var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ });
+
+ // Act
+ if (isStreaming)
+ {
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", arguments))
+ { }
+ }
+ else
+ {
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ await kernel.InvokePromptAsync("Test prompt", arguments);
+ }
+
+ // Assert
+ Assert.Equal("Filter1-Invoking", executionOrder[0]);
+ Assert.Equal("Filter2-Invoking", executionOrder[1]);
+ Assert.Equal("Filter3-Invoking", executionOrder[2]);
+ Assert.Equal("Filter3-Invoked", executionOrder[3]);
+ Assert.Equal("Filter2-Invoked", executionOrder[4]);
+ Assert.Equal("Filter1-Invoked", executionOrder[5]);
+ }
+
+ [Fact]
+ public async Task FilterCanOverrideArgumentsAsync()
+ {
+ // Arrange
+ const string NewValue = "NewValue";
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ context.Arguments!["parameter"] = NewValue;
+ await next(context);
+ context.Terminate = true;
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ // Act
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
+ }));
+
+ // Assert
+ Assert.Equal("NewValue", result.ToString());
+ }
+
+ [Fact]
+ public async Task FilterCanHandleExceptionAsync()
+ {
+ // Arrange
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2");
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ try
+ {
+ await next(context);
+ }
+ catch (KernelException exception)
+ {
+ Assert.Equal("Exception from Function1", exception.Message);
+ context.Result = new FunctionResult(context.Result, "Result from filter");
+ }
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+
+ var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+
+ var chatHistory = new ChatHistory();
+ chatHistory.AddSystemMessage("System message");
+
+ // Act
+ var result = await chatCompletion.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);
+
+ var firstFunctionResult = chatHistory[^2].Content;
+ var secondFunctionResult = chatHistory[^1].Content;
+
+ // Assert
+ Assert.Equal("Result from filter", firstFunctionResult);
+ Assert.Equal("Result from Function2", secondFunctionResult);
+ }
+
+ [Fact]
+ public async Task FilterCanHandleExceptionOnStreamingAsync()
+ {
+ // Arrange
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { throw new KernelException("Exception from Function1"); }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => "Result from Function2", "Function2");
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ try
+ {
+ await next(context);
+ }
+ catch (KernelException)
+ {
+ context.Result = new FunctionResult(context.Result, "Result from filter");
+ }
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+
+ var chatHistory = new ChatHistory();
+ var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+
+ // Act
+ await foreach (var item in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel))
+ { }
+
+ var firstFunctionResult = chatHistory[^2].Content;
+ var secondFunctionResult = chatHistory[^1].Content;
+
+ // Assert
+ Assert.Equal("Result from filter", firstFunctionResult);
+ Assert.Equal("Result from Function2", secondFunctionResult);
+ }
+
+ [Fact]
+ public async Task FiltersCanSkipFunctionExecutionAsync()
+ {
+ // Arrange
+ int filterInvocations = 0;
+ int firstFunctionInvocations = 0;
+ int secondFunctionInvocations = 0;
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ // Filter delegate is invoked only for the second function; the first one should be skipped.
+ if (context.Function.Name == "Function2")
+ {
+ await next(context);
+ }
+
+ filterInvocations++;
+ });
+
+ using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_multiple_function_calls_test_response.json")) };
+ using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) };
+
+ this._messageHandlerStub.ResponsesToReturn = [response1, response2];
+
+ // Act
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ }));
+
+ // Assert
+ Assert.Equal(2, filterInvocations);
+ Assert.Equal(0, firstFunctionInvocations);
+ Assert.Equal(1, secondFunctionInvocations);
+ }
+
+ [Fact]
+ public async Task PreFilterCanTerminateOperationAsync()
+ {
+ // Arrange
+ int firstFunctionInvocations = 0;
+ int secondFunctionInvocations = 0;
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ // Terminating before the first function, so no functions will be invoked.
+ context.Terminate = true;
+
+ await next(context);
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ // Act
+ await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ }));
+
+ // Assert
+ Assert.Equal(0, firstFunctionInvocations);
+ Assert.Equal(0, secondFunctionInvocations);
+ }
+
+ [Fact]
+ public async Task PreFilterCanTerminateOperationOnStreamingAsync()
+ {
+ // Arrange
+ int firstFunctionInvocations = 0;
+ int secondFunctionInvocations = 0;
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ // Terminating before the first function, so no functions will be invoked.
+ context.Terminate = true;
+
+ await next(context);
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+
+ // Act
+ await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
+ { }
+
+ // Assert
+ Assert.Equal(0, firstFunctionInvocations);
+ Assert.Equal(0, secondFunctionInvocations);
+ }
+
+ [Fact]
+ public async Task PostFilterCanTerminateOperationAsync()
+ {
+ // Arrange
+ int firstFunctionInvocations = 0;
+ int secondFunctionInvocations = 0;
+ List requestSequenceNumbers = [];
+ List functionSequenceNumbers = [];
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ requestSequenceNumbers.Add(context.RequestSequenceIndex);
+ functionSequenceNumbers.Add(context.FunctionSequenceIndex);
+
+ await next(context);
+
+ // Terminating after the first function, so the second function won't be invoked.
+ context.Terminate = true;
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ // Act
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ }));
+
+ // Assert
+ Assert.Equal(1, firstFunctionInvocations);
+ Assert.Equal(0, secondFunctionInvocations);
+ Assert.Equal([0], requestSequenceNumbers);
+ Assert.Equal([0], functionSequenceNumbers);
+
+ // Results of function invoked before termination should be returned
+ var lastMessageContent = result.GetValue();
+ Assert.NotNull(lastMessageContent);
+
+ Assert.Equal("function1-value", lastMessageContent.Content);
+ Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
+ }
+
+ [Fact]
+ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
+ {
+ // Arrange
+ int firstFunctionInvocations = 0;
+ int secondFunctionInvocations = 0;
+ List requestSequenceNumbers = [];
+ List functionSequenceNumbers = [];
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => { firstFunctionInvocations++; return parameter; }, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => { secondFunctionInvocations++; return parameter; }, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
+ {
+ requestSequenceNumbers.Add(context.RequestSequenceIndex);
+ functionSequenceNumbers.Add(context.FunctionSequenceIndex);
+
+ await next(context);
+
+ // Terminating after the first function, so the second function won't be invoked.
+ context.Terminate = true;
+ });
+
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+
+ List streamingContent = [];
+
+ // Act
+ await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
+ {
+ streamingContent.Add(item);
+ }
+
+ // Assert
+ Assert.Equal(1, firstFunctionInvocations);
+ Assert.Equal(0, secondFunctionInvocations);
+ Assert.Equal([0], requestSequenceNumbers);
+ Assert.Equal([0], functionSequenceNumbers);
+
+ // Results of function invoked before termination should be returned
+ Assert.Equal(3, streamingContent.Count);
+
+ var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent;
+ Assert.NotNull(lastMessageContent);
+
+ Assert.Equal("function1-value", lastMessageContent.Content);
+ Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
+ {
+ // Arrange
+ bool? actualStreamingFlag = null;
+
+ var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function1");
+ var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) => parameter, "Function2");
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
+
+ var filter = new AutoFunctionInvocationFilter(async (context, next) =>
+ {
+ actualStreamingFlag = context.IsStreaming;
+ await next(context);
+ });
+
+ var builder = Kernel.CreateBuilder();
+
+ builder.Plugins.Add(plugin);
+
+ builder.Services.AddSingleton((serviceProvider) =>
+ {
+ return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+ });
+
+ builder.Services.AddSingleton(filter);
+
+ var kernel = builder.Build();
+
+ var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ });
+
+ // Act
+ if (isStreaming)
+ {
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ await kernel.InvokePromptStreamingAsync("Test prompt", arguments).ToListAsync();
+ }
+ else
+ {
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ await kernel.InvokePromptAsync("Test prompt", arguments);
+ }
+
+ // Assert
+ Assert.Equal(isStreaming, actualStreamingFlag);
+ }
+
+ [Fact]
+ public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync()
+ {
+ // Arrange
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);
+
+ AutoFunctionInvocationContext? actualContext = null;
+
+ var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
+ {
+ actualContext = context;
+ return Task.CompletedTask;
+ });
+
+ var expectedExecutionSettings = new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ };
+
+ // Act
+ var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings));
+
+ // Assert
+ Assert.NotNull(actualContext);
+ Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
+ }
+
+ [Fact]
+ public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync()
+ {
+ // Arrange
+ this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
+
+ var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);
+
+ AutoFunctionInvocationContext? actualContext = null;
+
+ var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
+ {
+ actualContext = context;
+ return Task.CompletedTask;
+ });
+
+ var expectedExecutionSettings = new OpenAIPromptExecutionSettings
+ {
+ ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ };
+
+ // Act
+ await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings)))
+ { }
+
+ // Assert
+ Assert.NotNull(actualContext);
+ Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
+ }
+
+ public void Dispose()
+ {
+ this._httpClient.Dispose();
+ this._messageHandlerStub.Dispose();
+ }
+
+ #region private
+
+#pragma warning disable CA2000 // Dispose objects before losing scope
+ private static List GetFunctionCallingResponses()
+ {
+ return [
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) }
+ ];
+ }
+
+ private static List GetFunctionCallingStreamingResponses()
+ {
+ return [
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) }
+ ];
+ }
+#pragma warning restore CA2000
+
+ private Kernel GetKernelWithFilter(
+ KernelPlugin plugin,
+ Func, Task>? onAutoFunctionInvocation)
+ {
+ var builder = Kernel.CreateBuilder();
+ var filter = new AutoFunctionInvocationFilter(onAutoFunctionInvocation);
+
+ builder.Plugins.Add(plugin);
+ builder.Services.AddSingleton(filter);
+
+ builder.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient);
+
+ return builder.Build();
+ }
+
+ private sealed class AutoFunctionInvocationFilter(
+ Func, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter
+ {
+ private readonly Func, Task>? _onAutoFunctionInvocation = onAutoFunctionInvocation;
+
+ public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) =>
+ this._onAutoFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask;
+ }
+
+ #endregion
+}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs
index 19992be01667..1f23d84c4dfe 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs
@@ -312,7 +312,7 @@ public async Task FilterCanOverrideArgumentsAsync()
// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
// Assert
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs
new file mode 100644
index 000000000000..437d347aa194
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIKernelBuilderExtensionsChatClientTests.cs
@@ -0,0 +1,92 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Net.Http;
+using Microsoft.Extensions.AI;
+using Microsoft.SemanticKernel;
+using OpenAI;
+using Xunit;
+
+namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions;
+
+public class OpenAIKernelBuilderExtensionsChatClientTests
+{
+ [Fact]
+ public void AddOpenAIChatClientNullArgsThrow()
+ {
+ // Arrange
+ IKernelBuilder builder = null!;
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+
+ // Act & Assert
+ var exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId));
+ Assert.Equal("builder", exception.ParamName);
+
+ exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId));
+ Assert.Equal("builder", exception.ParamName);
+
+ using var httpClient = new HttpClient();
+ exception = Assert.Throws(() => builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient));
+ Assert.Equal("builder", exception.ParamName);
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientDefaultValidParametersRegistersService()
+ {
+ // Arrange
+ var builder = Kernel.CreateBuilder();
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+
+ // Act
+ builder.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId);
+
+ // Assert
+ var kernel = builder.Build();
+ Assert.NotNull(kernel.GetRequiredService());
+ Assert.NotNull(kernel.GetRequiredService(serviceId));
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService()
+ {
+ // Arrange
+ var builder = Kernel.CreateBuilder();
+ string modelId = "gpt-3.5-turbo";
+ var openAIClient = new OpenAIClient("test_api_key");
+ string serviceId = "test_service_id";
+
+ // Act
+ builder.AddOpenAIChatClient(modelId, openAIClient, serviceId);
+
+ // Assert
+ var kernel = builder.Build();
+ Assert.NotNull(kernel.GetRequiredService());
+ Assert.NotNull(kernel.GetRequiredService(serviceId));
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService()
+ {
+ // Arrange
+ var builder = Kernel.CreateBuilder();
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+ using var httpClient = new HttpClient();
+
+ // Act
+ builder.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient);
+
+ // Assert
+ var kernel = builder.Build();
+ Assert.NotNull(kernel.GetRequiredService());
+ Assert.NotNull(kernel.GetRequiredService(serviceId));
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs
new file mode 100644
index 000000000000..7a3888b95f30
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/OpenAIServiceCollectionExtensionsChatClientTests.cs
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Net.Http;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.SemanticKernel;
+using OpenAI;
+using Xunit;
+
+namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions;
+
+public class OpenAIServiceCollectionExtensionsChatClientTests
+{
+ [Fact]
+ public void AddOpenAIChatClientNullArgsThrow()
+ {
+ // Arrange
+ ServiceCollection services = null!;
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+
+ // Act & Assert
+ var exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId));
+ Assert.Equal("services", exception.ParamName);
+
+ exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new OpenAIClient(apiKey), serviceId));
+ Assert.Equal("services", exception.ParamName);
+
+ using var httpClient = new HttpClient();
+ exception = Assert.Throws(() => services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient));
+ Assert.Equal("services", exception.ParamName);
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientDefaultValidParametersRegistersService()
+ {
+ // Arrange
+ var services = new ServiceCollection();
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+
+ // Act
+ services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId);
+
+ // Assert
+ var serviceProvider = services.BuildServiceProvider();
+ var chatClient = serviceProvider.GetKeyedService(serviceId);
+ Assert.NotNull(chatClient);
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientOpenAIClientValidParametersRegistersService()
+ {
+ // Arrange
+ var services = new ServiceCollection();
+ string modelId = "gpt-3.5-turbo";
+ var openAIClient = new OpenAIClient("test_api_key");
+ string serviceId = "test_service_id";
+
+ // Act
+ services.AddOpenAIChatClient(modelId, openAIClient, serviceId);
+
+ // Assert
+ var serviceProvider = services.BuildServiceProvider();
+ var chatClient = serviceProvider.GetKeyedService(serviceId);
+ Assert.NotNull(chatClient);
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientCustomEndpointValidParametersRegistersService()
+ {
+ // Arrange
+ var services = new ServiceCollection();
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+ using var httpClient = new HttpClient();
+ // Act
+ services.AddOpenAIChatClient(modelId, new Uri("http://localhost"), apiKey, orgId, serviceId, httpClient);
+ // Assert
+ var serviceProvider = services.BuildServiceProvider();
+ var chatClient = serviceProvider.GetKeyedService(serviceId);
+ Assert.NotNull(chatClient);
+ }
+
+ [Fact]
+ public void AddOpenAIChatClientWorksWithKernel()
+ {
+ var services = new ServiceCollection();
+ string modelId = "gpt-3.5-turbo";
+ string apiKey = "test_api_key";
+ string orgId = "test_org_id";
+ string serviceId = "test_service_id";
+
+ // Act
+ services.AddOpenAIChatClient(modelId, apiKey, orgId, serviceId);
+ services.AddKernel();
+
+ var serviceProvider = services.BuildServiceProvider();
+ var kernel = serviceProvider.GetRequiredService();
+
+ var serviceFromCollection = serviceProvider.GetKeyedService(serviceId);
+ var serviceFromKernel = kernel.GetRequiredService(serviceId);
+
+ Assert.NotNull(serviceFromKernel);
+ Assert.Same(serviceFromCollection, serviceFromKernel);
+ }
+}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
index e38aa5289104..ac0005f5d5c5 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
@@ -10,6 +10,7 @@ namespace Microsoft.SemanticKernel;
///
/// Sponsor extensions class for .
///
+[Experimental("SKEXP0010")]
public static class OpenAIChatClientKernelBuilderExtensions
{
#region Chat Completion
@@ -79,7 +80,6 @@ public static IKernelBuilder AddOpenAIChatClient(
/// A local identifier for the given AI service
/// The HttpClient to use with this service.
/// The same instance as .
- [Experimental("SKEXP0010")]
public static IKernelBuilder AddOpenAIChatClient(
this IKernelBuilder builder,
string modelId,
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
index ff214c441c2e..2ae915869698 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
@@ -3,6 +3,7 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
+using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
@@ -13,14 +14,16 @@
namespace Microsoft.SemanticKernel;
-#pragma warning disable IDE0039 // Use local function
-
///
/// Sponsor extensions class for .
///
+[Experimental("SKEXP0010")]
public static class OpenAIChatClientServiceCollectionExtensions
{
- #region Chat Completion
+ ///
+ /// White space constant.
+ ///
+ private const string SingleSpace = " ";
///
/// Adds the OpenAI chat completion service to the list.
@@ -148,10 +151,4 @@ private static OpenAIClientOptions GetClientOptions(
return options;
}
-
- ///
- /// White space constant.
- ///
- private const string SingleSpace = " ";
- #endregion
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
index da5af46620fd..d782a7262e68 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
@@ -18,6 +18,8 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
[ExcludeFromCodeCoverage]
internal sealed class KernelFunctionInvocationContext
{
+ internal const string KernelKey = "Kernel";
+
///
/// A nop function used to allow to be non-nullable. Default instances of
/// start with this as the target function.
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 1a59b8f5ccbd..117e1ebb4f81 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -186,6 +186,7 @@ public override async Task GetResponseAsync(
IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
Verify.NotNull(messages);
+ options?.AdditionalProperties?.TryGetValue(KernelFunctionInvocationContext.KernelKey, out var kernel);
// A single request into this GetResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
@@ -254,7 +255,7 @@ public override async Task GetResponseAsync(
// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, kernel, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId))
@@ -504,15 +505,16 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
/// The options used for the response being processed.
/// The function call contents representing the functions to be invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
+ /// The containing the auto function invocations.
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync(
- List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken)
+ List messages, ChatOptions options, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
- Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call.");
+ Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call.");
// Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel.
if (functionCallContents.Count == 1)
@@ -694,13 +696,12 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
}
/// Invokes the function asynchronously.
- ///
- /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
- ///
+ /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
+ /// The containing the auto function invocations.
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
/// is .
- internal async Task
/// The current chat contents, inclusive of the function call contents being processed.
/// The options used for the response being processed.
+ /// The response from the inner client.
/// The function call contents representing the functions to be invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
/// The containing the auto function invocations.
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync(
- List messages, ChatOptions options, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken)
+ List messages, ChatOptions options, ChatResponse response, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
@@ -520,7 +522,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
if (functionCallContents.Count == 1)
{
FunctionInvocationResult result = await ProcessFunctionCallAsync(
- messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
+ messages, options, response, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
IList added = CreateResponseMessages([result]);
ThrowIfNoFunctionResultsAdded(added);
@@ -585,13 +587,14 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
/// Processes the function call described in [].
/// The current chat contents, inclusive of the function call contents being processed.
/// The options used for the response being processed.
+ /// The response from the inner client.
/// The function call contents representing all the functions being invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
/// The 0-based index of the function being called out of .
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task ProcessFunctionCallAsync(
- List messages, ChatOptions options, List callContents,
+ List messages, ChatOptions options, ChatResponse response, List callContents,
int iteration, int functionCallIndex, CancellationToken cancellationToken)
{
var callContent = callContents[functionCallIndex];
@@ -695,17 +698,60 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
}
}
+ ///
+ /// Invokes the auto function invocation filters.
+ ///
+ /// The .
+ /// The auto function invocation context.
+ /// The function to call after the filters.
+ /// The auto function invocation context.
+ private async Task OnAutoFunctionInvocationAsync(
+ Kernel kernel,
+ AutoFunctionInvocationContext context,
+ Func functionCallCallback)
+ {
+ await this.InvokeFilterOrFunctionAsync(kernel.AutoFunctionInvocationFilters, functionCallCallback, context).ConfigureAwait(false);
+
+ return context;
+ }
+
+ ///
+ /// This method will execute the auto function invocation filters and the function recursively.
+ /// If there are no registered filters, just the function will be executed.
+ /// If there are registered filters, the filter at the given position will be executed.
+ /// The second parameter of a filter is a callback: either the filter at the next position, or the function itself if there are no remaining filters to execute.
+ /// The function will always be executed as the last step, after all filters.
+ ///
+ private async Task InvokeFilterOrFunctionAsync(
+ IList? autoFunctionInvocationFilters,
+ Func functionCallCallback,
+ KernelFunctionInvocationContext context,
+ int index = 0)
+ {
+ if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count)
+ {
+ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
+ context.ToAutoFunctionInvocationContext(),
+ (context) => this.InvokeFilterOrFunctionAsync(autoFunctionInvocationFilters, functionCallCallback, context, index + 1)
+ ).ConfigureAwait(false);
+ }
+ else
+ {
+ await functionCallCallback(context).ConfigureAwait(false);
+ }
+ }
+
/// Invokes the function asynchronously.
- /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
+ /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
/// The containing the auto function invocations.
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
- /// is .
- internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext context, Kernel kernel, CancellationToken cancellationToken)
+ /// is .
+ internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, Kernel kernel, CancellationToken cancellationToken)
{
- Verify.NotNull(context);
+ Verify.NotNull(invocationContext);
- using Activity? activity = _activitySource?.StartActivity(context.Function.Name);
+ using Activity? activity = _activitySource?.StartActivity(invocationContext.Function.Name);
long startingTimestamp = 0;
if (_logger.IsEnabled(LogLevel.Debug))
@@ -713,20 +759,45 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
startingTimestamp = Stopwatch.GetTimestamp();
if (_logger.IsEnabled(LogLevel.Trace))
{
- LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.Function.JsonSerializerOptions));
+ LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.Function.JsonSerializerOptions));
}
else
{
- LogInvoking(context.Function.Name);
+ LogInvoking(invocationContext.Function.Name);
}
}
object? result = null;
try
{
- CurrentContext = context;
- context = this.OnAutoFunctionInvocationAsync(kernel, context, )
- result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ CurrentContext = invocationContext;
+ var autoFunctionInvocationContext = invocationContext.ToAutoFunctionInvocationContext(kernel, )
+ invocationContext = this.OnAutoFunctionInvocationAsync(
+ kernel,
+ invocationContext,
+ async (context) =>
+ {
+ // Check if filter requested termination
+ if (context.Terminate)
+ {
+ return;
+ }
+
+ // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
+ // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
+ // as the called function could in turn telling the model about itself as a possible candidate for invocation.
+ KernelArguments? arguments = null;
+ if (context.CallContent.Arguments is not null)
+ {
+ arguments = new(context.CallContent.Arguments);
+ }
+
+ var result = await invocationContext.Function.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ context.Result = new FunctionResult(context.Function.AsKernelFunction(), result);
+ }).ConfigureAwait(false);
+
+ )
+ result =
}
catch (Exception e)
{
@@ -738,11 +809,11 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
if (e is OperationCanceledException)
{
- LogInvocationCanceled(context.Function.Name);
+ LogInvocationCanceled(invocationContext.Function.Name);
}
else
{
- LogInvocationFailed(context.Function.Name, e);
+ LogInvocationFailed(invocationContext.Function.Name, e);
}
throw;
@@ -755,11 +826,11 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
if (result is not null && _logger.IsEnabled(LogLevel.Trace))
{
- LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.Function.JsonSerializerOptions));
+ LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.Function.JsonSerializerOptions));
}
else
{
- LogInvocationCompleted(context.Function.Name, elapsed);
+ LogInvocationCompleted(invocationContext.Function.Name, elapsed);
}
}
}
From 9c892e1efb7af30c093ed3deb846e0ae2e409bbf Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 4 Apr 2025 18:42:56 +0100
Subject: [PATCH 04/26] Auto vs KernelFIC
---
.../KernelFunctionInvocationContext.cs | 5 ---
...rnelFunctionInvocationContextExtensions.cs | 20 +++++-----
.../KernelFunctionInvokingChatClient.cs | 38 ++++++++++---------
.../AutoFunctionInvocationContext.cs | 9 ++++-
4 files changed, 38 insertions(+), 34 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
index 344e021da7ad..da5af46620fd 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
@@ -18,8 +18,6 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
[ExcludeFromCodeCoverage]
internal sealed class KernelFunctionInvocationContext
{
- internal const string KernelKey = "Kernel";
-
///
/// A nop function used to allow to be non-nullable. Default instances of
/// start with this as the target function.
@@ -40,9 +38,6 @@ internal KernelFunctionInvocationContext()
{
}
- /// Chat response information
- public ChatResponse? Response { get; set; }
-
/// Gets or sets the function call content information associated with this invocation.
public Microsoft.Extensions.AI.FunctionCallContent CallContent
{
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs
index 61449fcfa55f..475698c89fd5 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
-using Microsoft.Extensions.AI;
+using System.Linq;
namespace Microsoft.SemanticKernel.ChatCompletion;
@@ -11,27 +11,27 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
/// Provides context for an in-flight function invocation.
internal static class KernelFunctionInvocationContextExtensions
{
- internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context, Microsoft.Extensions.AI.FunctionCallContent functionCall, Kernel kernel, ChatMessage message, bool isStreaming)
+ internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context)
{
if (context is null)
{
throw new ArgumentNullException(nameof(context));
}
+ var kernelFunction = context.Function.AsKernelFunction();
return new AutoFunctionInvocationContext(
- kernel: kernel,
- function: context.Function.AsKernelFunction(),
- result: null,
+ kernel: context.Kernel,
+ function: kernelFunction,
+ result: context.Result,
chatHistory: context.Messages.ToChatHistory(),
- chatMessageContent: message.ToChatMessageContent())
+ chatMessageContent: context.Response!.Messages.Last().ToChatMessageContent())
{
- Arguments = new(functionCall.Arguments),
- FunctionName = functionCall.Name,
+ Arguments = context.CallContent.Arguments is null ? null : new(context.CallContent.Arguments),
FunctionCount = context.FunctionCount,
FunctionSequenceIndex = context.FunctionCallIndex,
RequestSequenceIndex = context.Iteration,
- IsStreaming = isStreaming,
- ToolCallId = functionCall.CallId,
+ IsStreaming = context.IsStreaming,
+ ToolCallId = context.CallContent.CallId,
ExecutionSettings = context.Options.ToPromptExecutionSettings()
};
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 9096f3f310e8..99c689d61369 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -585,6 +585,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
}
/// Processes the function call described in [].
+ /// The current chat contents, inclusive of the function call contents being processed.
/// The options used for the response being processed.
/// The response from the inner client.
@@ -594,6 +595,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task ProcessFunctionCallAsync(
+ Kernel kernel,
List messages, ChatOptions options, ChatResponse response, List callContents,
int iteration, int functionCallIndex, CancellationToken cancellationToken)
{
@@ -608,7 +610,9 @@ private async Task ProcessFunctionCallAsync(
KernelFunctionInvocationContext context = new()
{
+ Kernel = kernel,
Messages = messages,
+ Response = response,
Options = options,
CallContent = callContent,
Function = function,
@@ -701,18 +705,17 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
///
/// Invokes the auto function invocation filters.
///
- /// The .
- /// The auto function invocation context.
+ /// The auto function invocation context.
/// The function to call after the filters.
/// The auto function invocation context.
- private async Task OnAutoFunctionInvocationAsync(
- Kernel kernel,
- AutoFunctionInvocationContext context,
+ private async Task OnAutoFunctionInvocationAsync(
+ KernelFunctionInvocationContext kernelContext,
Func functionCallCallback)
{
- await this.InvokeFilterOrFunctionAsync(kernel.AutoFunctionInvocationFilters, functionCallCallback, context).ConfigureAwait(false);
+ var autoContext = invocationContext.ToAutoFunctionInvocationContext();
+ await this.InvokeFilterOrFunctionAsync(functionCallCallback, kernelContext).ConfigureAwait(false);
- return context;
+ return kernelContext;
}
///
@@ -723,16 +726,17 @@ private async Task OnAutoFunctionInvocationAsyn
/// Function will be always executed as last step after all filters.
///
private async Task InvokeFilterOrFunctionAsync(
- IList? autoFunctionInvocationFilters,
- Func functionCallCallback,
- KernelFunctionInvocationContext context,
+ Func functionCallCallback,
+ AutoFunctionInvocationContext context,
int index = 0)
{
+ IList? autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters;
+
if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count)
{
await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
- context.ToAutoFunctionInvocationContext(),
- (context) => this.InvokeFilterOrFunctionAsync(autoFunctionInvocationFilters, functionCallCallback, context, index + 1)
+ context,
+ (context) => this.InvokeFilterOrFunctionAsync(functionCallCallback, context, index + 1)
).ConfigureAwait(false);
}
else
@@ -747,7 +751,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
/// is .
- internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, Kernel kernel, CancellationToken cancellationToken)
+ internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, CancellationToken cancellationToken)
{
Verify.NotNull(invocationContext);
@@ -771,9 +775,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
try
{
CurrentContext = invocationContext;
- var autoFunctionInvocationContext = invocationContext.ToAutoFunctionInvocationContext(kernel, )
invocationContext = this.OnAutoFunctionInvocationAsync(
- kernel,
invocationContext,
async (context) =>
{
@@ -787,13 +789,13 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
// further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
// as the called function could in turn telling the model about itself as a possible candidate for invocation.
KernelArguments? arguments = null;
- if (context.CallContent.Arguments is not null)
+ if (context.Arguments is not null)
{
- arguments = new(context.CallContent.Arguments);
+ arguments = new(context.Arguments);
}
var result = await invocationContext.Function.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
- context.Result = new FunctionResult(context.Function.AsKernelFunction(), result);
+ context.Result = new FunctionResult(context.Function, result);
}).ConfigureAwait(false);
)
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index d13d5519b652..d61923eda663 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -9,8 +9,15 @@ namespace Microsoft.SemanticKernel;
///
/// Class with data related to automatic function invocation.
///
-public class AutoFunctionInvocationContext
+public class AutoFunctionInvocationContext : KernelFunctionInvocationContext
{
+ private KernelFunctionInvocationContext _innerContext;
+
+ public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContext)
+ {
+ this._innerContext = innerContext;
+ }
+
///
/// Initializes a new instance of the class.
///
From be5583e0f4a814734483d5aa1cd77924f8f74412 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 4 Apr 2025 20:33:50 +0100
Subject: [PATCH 05/26] AutoFunctionInvocationContext as KFIC
---
.../KernelFunctionInvocationContext.cs | 6 +-
...rnelFunctionInvocationContextExtensions.cs | 38 ------
.../KernelFunctionInvokingChatClient.cs | 89 ++++++------
.../AI/ChatCompletion/ChatHistory.cs | 27 ++--
.../AutoFunctionInvocationContext.cs | 129 +++++++++++++++---
5 files changed, 162 insertions(+), 127 deletions(-)
delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
index da5af46620fd..a36f03630fd5 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
@@ -16,10 +16,10 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
/// Provides context for an in-flight function invocation.
[ExcludeFromCodeCoverage]
-internal sealed class KernelFunctionInvocationContext
+public class KernelFunctionInvocationContext
{
///
- /// A nop function used to allow to be non-nullable. Default instances of
+ /// A nop function used to allow to be non-nullable. Default instances of
/// start with this as the target function.
///
private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext));
@@ -64,7 +64,7 @@ public IList Messages
public ChatOptions? Options { get; set; }
/// Gets or sets the AI function to be invoked.
- public AIFunction Function
+ public AIFunction AIFunction
{
get => _function;
set
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs
deleted file mode 100644
index 475698c89fd5..000000000000
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContextExtensions.cs
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Linq;
-
-namespace Microsoft.SemanticKernel.ChatCompletion;
-
-// Slight modified source from
-// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
-
-/// Provides context for an in-flight function invocation.
-internal static class KernelFunctionInvocationContextExtensions
-{
- internal static AutoFunctionInvocationContext ToAutoFunctionInvocationContext(this KernelFunctionInvocationContext context)
- {
- if (context is null)
- {
- throw new ArgumentNullException(nameof(context));
- }
-
- var kernelFunction = context.Function.AsKernelFunction();
- return new AutoFunctionInvocationContext(
- kernel: context.Kernel,
- function: kernelFunction,
- result: context.Result,
- chatHistory: context.Messages.ToChatHistory(),
- chatMessageContent: context.Response!.Messages.Last().ToChatMessageContent())
- {
- Arguments = context.CallContent.Arguments is null ? null : new(context.CallContent.Arguments),
- FunctionCount = context.FunctionCount,
- FunctionSequenceIndex = context.FunctionCallIndex,
- RequestSequenceIndex = context.Iteration,
- IsStreaming = context.IsStreaming,
- ToolCallId = context.CallContent.CallId,
- ExecutionSettings = context.Options.ToPromptExecutionSettings()
- };
- }
-}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 99c689d61369..a17e9d02730a 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -12,7 +12,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
-using static System.Runtime.InteropServices.JavaScript.JSType;
#pragma warning disable CA2213 // Disposable fields should be disposed
#pragma warning disable IDE0009 // Use explicit 'this.' qualifier
@@ -187,7 +186,6 @@ public override async Task GetResponseAsync(
IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
Verify.NotNull(messages);
- options?.AdditionalProperties?.TryGetValue(KernelFunctionInvocationContext.KernelKey, out var kernel);
// A single request into this GetResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
@@ -256,7 +254,7 @@ public override async Task GetResponseAsync(
// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, response, functionCallContents!, iteration, kernel, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId))
@@ -504,14 +502,12 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
///
/// The current chat contents, inclusive of the function call contents being processed.
/// The options used for the response being processed.
- /// The response from the inner client.
/// The function call contents representing the functions to be invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
- /// The containing the auto function invocations.
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync(
- List messages, ChatOptions options, ChatResponse response, List functionCallContents, int iteration, Kernel kernel, CancellationToken cancellationToken)
+ List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
@@ -522,7 +518,7 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
if (functionCallContents.Count == 1)
{
FunctionInvocationResult result = await ProcessFunctionCallAsync(
- messages, options, response, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
+ messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
IList added = CreateResponseMessages([result]);
ThrowIfNoFunctionResultsAdded(added);
@@ -585,18 +581,15 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
}
/// Processes the function call described in [].
- /// The current chat contents, inclusive of the function call contents being processed.
/// The options used for the response being processed.
- /// The response from the inner client.
/// The function call contents representing all the functions being invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
/// The 0-based index of the function being called out of .
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task ProcessFunctionCallAsync(
- Kernel kernel,
- List messages, ChatOptions options, ChatResponse response, List callContents,
+ List messages, ChatOptions options, List callContents,
int iteration, int functionCallIndex, CancellationToken cancellationToken)
{
var callContent = callContents[functionCallIndex];
@@ -610,12 +603,10 @@ private async Task ProcessFunctionCallAsync(
KernelFunctionInvocationContext context = new()
{
- Kernel = kernel,
Messages = messages,
- Response = response,
Options = options,
CallContent = callContent,
- Function = function,
+ AIFunction = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count,
@@ -705,17 +696,16 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
///
/// Invokes the auto function invocation filters.
///
- /// The auto function invocation context.
+ /// The auto function invocation context.
/// The function to call after the filters.
/// The auto function invocation context.
private async Task OnAutoFunctionInvocationAsync(
- KernelFunctionInvocationContext kernelContext,
+ AutoFunctionInvocationContext context,
Func functionCallCallback)
{
- var autoContext = invocationContext.ToAutoFunctionInvocationContext();
- await this.InvokeFilterOrFunctionAsync(functionCallCallback, kernelContext).ConfigureAwait(false);
+ await this.InvokeFilterOrFunctionAsync(functionCallCallback, context).ConfigureAwait(false);
- return kernelContext;
+ return context;
}
///
@@ -747,7 +737,6 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
/// Invokes the function asynchronously.
/// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
- /// The containing the auto function invocations.
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
/// is .
@@ -755,7 +744,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
{
Verify.NotNull(invocationContext);
- using Activity? activity = _activitySource?.StartActivity(invocationContext.Function.Name);
+ using Activity? activity = _activitySource?.StartActivity(invocationContext.AIFunction.Name);
long startingTimestamp = 0;
if (_logger.IsEnabled(LogLevel.Debug))
@@ -763,11 +752,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
startingTimestamp = Stopwatch.GetTimestamp();
if (_logger.IsEnabled(LogLevel.Trace))
{
- LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.Function.JsonSerializerOptions));
+ LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions));
}
else
{
- LogInvoking(invocationContext.Function.Name);
+ LogInvoking(invocationContext.AIFunction.Name);
}
}
@@ -775,31 +764,31 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
try
{
CurrentContext = invocationContext;
- invocationContext = this.OnAutoFunctionInvocationAsync(
- invocationContext,
- async (context) =>
+ if (invocationContext is AutoFunctionInvocationContext autoFunctionInvocationContext)
{
- // Check if filter requested termination
- if (context.Terminate)
- {
- return;
- }
-
- // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
- // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
- // as the called function could in turn telling the model about itself as a possible candidate for invocation.
- KernelArguments? arguments = null;
- if (context.Arguments is not null)
+ autoFunctionInvocationContext = await this.OnAutoFunctionInvocationAsync(
+ autoFunctionInvocationContext,
+ async (context) =>
{
- arguments = new(context.Arguments);
- }
-
- var result = await invocationContext.Function.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
- context.Result = new FunctionResult(context.Function, result);
- }).ConfigureAwait(false);
-
- )
- result =
+ // Check if filter requested termination
+ if (context.Terminate)
+ {
+ return;
+ }
+
+ // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
+ // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
+ // as the called function could in turn telling the model about itself as a possible candidate for invocation.
+ KernelArguments? arguments = null;
+ if (context.Arguments is not null)
+ {
+ arguments = new(context.Arguments);
+ }
+
+ var result = await invocationContext.AIFunction.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
+ context.Result = new FunctionResult(context.Function, result);
+ }).ConfigureAwait(false);
+ }
}
catch (Exception e)
{
@@ -811,11 +800,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (e is OperationCanceledException)
{
- LogInvocationCanceled(invocationContext.Function.Name);
+ LogInvocationCanceled(invocationContext.AIFunction.Name);
}
else
{
- LogInvocationFailed(invocationContext.Function.Name, e);
+ LogInvocationFailed(invocationContext.AIFunction.Name, e);
}
throw;
@@ -828,11 +817,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (result is not null && _logger.IsEnabled(LogLevel.Trace))
{
- LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.Function.JsonSerializerOptions));
+ LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions));
}
else
{
- LogInvocationCompleted(invocationContext.Function.Name, elapsed);
+ LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed);
}
}
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
index 22968c47ea38..cfdad840349c 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
@@ -37,8 +37,7 @@ public ChatHistory(string message, AuthorRole role)
{
Verify.NotNullOrWhiteSpace(message);
- this._messages = [];
- this.Add(new ChatMessageContent(role, message));
+ this._messages = [new ChatMessageContent(role, message)];
}
///
@@ -60,7 +59,7 @@ public ChatHistory(IEnumerable messages)
}
/// Gets the number of messages in the history.
- public int Count => this._messages.Count;
+ public virtual int Count => this._messages.Count;
///
/// Role of the message author
@@ -118,7 +117,7 @@ public void AddDeveloperMessage(string content) =>
/// Adds a message to the history.
/// The message to add.
/// is null.
- public void Add(ChatMessageContent item)
+ public virtual void Add(ChatMessageContent item)
{
Verify.NotNull(item);
this._messages.Add(item);
@@ -127,7 +126,7 @@ public void Add(ChatMessageContent item)
/// Adds the messages to the history.
/// The collection whose messages should be added to the history.
/// is null.
- public void AddRange(IEnumerable items)
+ public virtual void AddRange(IEnumerable items)
{
Verify.NotNull(items);
this._messages.AddRange(items);
@@ -137,7 +136,7 @@ public void AddRange(IEnumerable items)
/// The index at which the item should be inserted.
/// The message to insert.
/// is null.
- public void Insert(int index, ChatMessageContent item)
+ public virtual void Insert(int index, ChatMessageContent item)
{
Verify.NotNull(item);
this._messages.Insert(index, item);
@@ -151,17 +150,17 @@ public void Insert(int index, ChatMessageContent item)
/// is null.
/// The number of messages in the history is greater than the available space from to the end of .
/// is less than 0.
- public void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex);
+ public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex);
/// Removes all messages from the history.
- public void Clear() => this._messages.Clear();
+ public virtual void Clear() => this._messages.Clear();
/// Gets or sets the message at the specified index in the history.
/// The index of the message to get or set.
/// The message at the specified index.
/// is null.
/// The was not valid for this history.
- public ChatMessageContent this[int index]
+ public virtual ChatMessageContent this[int index]
{
get => this._messages[index];
set
@@ -175,7 +174,7 @@ public ChatMessageContent this[int index]
/// The message to locate.
/// true if the message is found in the history; otherwise, false.
/// is null.
- public bool Contains(ChatMessageContent item)
+ public virtual bool Contains(ChatMessageContent item)
{
Verify.NotNull(item);
return this._messages.Contains(item);
@@ -185,7 +184,7 @@ public bool Contains(ChatMessageContent item)
/// The message to locate.
/// The index of the first found occurrence of the specified message; -1 if the message could not be found.
/// is null.
- public int IndexOf(ChatMessageContent item)
+ public virtual int IndexOf(ChatMessageContent item)
{
Verify.NotNull(item);
return this._messages.IndexOf(item);
@@ -194,13 +193,13 @@ public int IndexOf(ChatMessageContent item)
/// Removes the message at the specified index from the history.
/// The index of the message to remove.
/// The was not valid for this history.
- public void RemoveAt(int index) => this._messages.RemoveAt(index);
+ public virtual void RemoveAt(int index) => this._messages.RemoveAt(index);
/// Removes the first occurrence of the specified message from the history.
/// The message to remove from the history.
/// true if the item was successfully removed; false if it wasn't located in the history.
/// is null.
- public bool Remove(ChatMessageContent item)
+ public virtual bool Remove(ChatMessageContent item)
{
Verify.NotNull(item);
return this._messages.Remove(item);
@@ -214,7 +213,7 @@ public bool Remove(ChatMessageContent item)
/// is less than 0.
/// is less than 0.
/// and do not denote a valid range of messages.
- public void RemoveRange(int index, int count)
+ public virtual void RemoveRange(int index, int count)
{
this._messages.RemoveRange(index, count);
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index d61923eda663..68e06140cc9b 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -1,7 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
+using System.Collections;
+using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
+using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;
namespace Microsoft.SemanticKernel;
@@ -11,11 +14,28 @@ namespace Microsoft.SemanticKernel;
///
public class AutoFunctionInvocationContext : KernelFunctionInvocationContext
{
- private KernelFunctionInvocationContext _innerContext;
-
+ private readonly KernelFunction? _kernelFunction;
+ private readonly KernelFunctionInvocationContext? _innerContext;
+ ///
+ /// Initializes a new instance of the <see cref="AutoFunctionInvocationContext"/> class from an existing <see cref="KernelFunctionInvocationContext"/>.
+ ///
public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContext)
{
+ Verify.NotNull(innerContext);
+ Verify.NotNull(innerContext.Options);
+ Verify.NotNull(innerContext.Options.AdditionalProperties);
+
+ innerContext.Options.AdditionalProperties.TryGetValue("Kernel", out var kernel);
+ Verify.NotNull(kernel);
+
+ innerContext.Options.AdditionalProperties.TryGetValue("ChatMessageContent", out var chatMessageContent);
+ Verify.NotNull(chatMessageContent);
+
this._innerContext = innerContext;
+ this.ChatHistory = new ChatMessageHistory(innerContext.Messages);
+ this.ChatMessageContent = chatMessageContent;
+ this.Kernel = kernel;
+ this.Result = new FunctionResult(this._kernelFunction!) { Culture = kernel.Culture };
}
///
@@ -40,7 +60,7 @@ public AutoFunctionInvocationContext(
Verify.NotNull(chatMessageContent);
this.Kernel = kernel;
- this.Function = function;
+ this._kernelFunction = function;
this.Result = result;
this.ChatHistory = chatHistory;
this.ChatMessageContent = chatMessageContent;
@@ -72,11 +92,6 @@ public AutoFunctionInvocationContext(
///
public int FunctionSequenceIndex { get; init; }
- ///
- /// Number of functions that will be invoked during auto function invocation request.
- ///
- public int FunctionCount { get; init; }
-
///
/// The ID of the tool call.
///
@@ -101,7 +116,10 @@ public AutoFunctionInvocationContext(
///
/// Gets the with which this filter is associated.
///
- public KernelFunction Function { get; }
+ public KernelFunction Function
+ {
+ get => this._innerContext?.AIFunction.AsKernelFunction() ?? this._kernelFunction!;
+ }
///
/// Gets the containing services, plugins, and other state for use throughout the operation.
@@ -114,18 +132,85 @@ public AutoFunctionInvocationContext(
public FunctionResult Result { get; set; }
///
- /// Gets or sets a value indicating whether the operation associated with the filter should be terminated.
- ///
- /// By default, this value is , which means all functions will be invoked.
- /// If set to , the behavior depends on how functions are invoked:
- ///
- /// - If functions are invoked sequentially (the default behavior), the remaining functions will not be invoked,
- /// and the last request to the LLM will not be performed.
- ///
- /// - If functions are invoked concurrently (controlled by the option),
- /// other functions will still be invoked, and the last request to the LLM will not be performed.
- ///
- /// In both cases, the automatic function invocation process will be terminated, and the result of the last executed function will be returned to the caller.
+ /// Mutable chat history adapter that keeps the underlying chat message list synchronized with this history.
///
- public bool Terminate { get; set; }
+ internal class ChatMessageHistory : ChatHistory, IEnumerable
+ {
+ private readonly List _messages;
+
+ public ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory())
+ {
+ this._messages = new List(messages);
+ }
+
+ public override void Add(ChatMessageContent item)
+ {
+ item.GetHashCode();
+ base.Add(item);
+ this._messages.Add(item.ToChatMessage());
+ }
+
+ public override void Clear()
+ {
+ base.Clear();
+ this._messages.Clear();
+ }
+
+ public override bool Contains(ChatMessageContent item)
+ {
+ return base.Contains(item);
+ }
+
+ public override bool Remove(ChatMessageContent item)
+ {
+ var index = base.IndexOf(item);
+
+ if (index < 0)
+ {
+ return false;
+ }
+
+ this._messages.RemoveAt(index);
+ base.RemoveAt(index);
+
+ return true;
+ }
+
+ public override void Insert(int index, ChatMessageContent item)
+ {
+ base.Insert(index, item);
+ this._messages.Insert(index, item.ToChatMessage());
+ }
+
+ public override void RemoveAt(int index)
+ {
+ this._messages.RemoveAt(index);
+ base.RemoveAt(index);
+ }
+
+ public override ChatMessageContent this[int index]
+ {
+ get => this._messages[index].ToChatMessageContent();
+ set
+ {
+ this._messages[index] = value.ToChatMessage();
+ base[index] = value;
+ }
+ }
+
+ public override int Count => this._messages.Count;
+
+ // Explicit implementation of the generic IEnumerable<ChatMessageContent>.GetEnumerator()
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ foreach (var message in this._messages)
+ {
+ yield return message.ToChatMessageContent(); // Convert and yield each item
+ }
+ }
+
+ // Explicit implementation of non-generic IEnumerable.GetEnumerator()
+ IEnumerator IEnumerable.GetEnumerator()
+ => ((IEnumerable)this).GetEnumerator();
+ }
}
From 8663ab08d46fe5dbd517c79e76e0733d55e19fd6 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 7 Apr 2025 16:34:03 +0100
Subject: [PATCH 06/26] AutoFunctionInvocation WIP
---
.../Connectors.OpenAI.UnitTests.csproj | 6 +
...FunctionInvocationFilterChatClientTests.cs | 10 +-
..._multiple_function_calls_test_response.txt | 9 ++
...multiple_function_calls_test_response.json | 40 +++++
.../AI/ChatClient/ChatOptionsExtensions.cs | 21 +++
.../KernelFunctionInvokingChatClient.cs | 143 ++++++++++++------
.../AI/PromptExecutionSettingsExtensions.cs | 4 +-
.../AutoFunctionInvocationContext.cs | 9 +-
.../Functions/KernelFunction.cs | 2 +-
9 files changed, 185 insertions(+), 59 deletions(-)
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
index 57f983dd5c9b..33557430e615 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
@@ -96,6 +96,12 @@
Always
+
+ Always
+
+
+ Always
+
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index 9a8e91a25df5..d28c67346c3f 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -596,7 +596,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
Assert.Equal([0], requestSequenceNumbers);
Assert.Equal([0], functionSequenceNumbers);
- // Results of function invoked before termination should be returned
+ // Results of function invoked before termination should be returned
Assert.Equal(3, streamingContent.Count);
var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent;
@@ -732,8 +732,8 @@ public void Dispose()
private static List GetFunctionCallingResponses()
{
return [
- new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) },
- new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_multiple_function_calls_test_response.json")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_multiple_function_calls_test_response.json")) },
new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) }
];
}
@@ -741,8 +741,8 @@ private static List GetFunctionCallingResponses()
private static List GetFunctionCallingStreamingResponses()
{
return [
- new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) },
- new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_streaming_multiple_function_calls_test_response.txt")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) },
+ new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt")) },
new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) }
];
}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt
new file mode 100644
index 000000000000..17ce94647fd5
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/chat_completion_streaming_chatclient_multiple_function_calls_test_response.txt
@@ -0,0 +1,9 @@
+data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_GetCurrentWeather","arguments":"{\n\"location\": \"Boston, MA\"\n}"}}]},"finish_reason":"tool_calls"}]}
+
+data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_FunctionWithException","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]}
+
+data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":2,"id":"3","type":"function","function":{"name":"MyPlugin_NonExistentFunction","arguments":"{\n\"argument\": \"value\"\n}"}}]},"finish_reason":"tool_calls"}]}
+
+data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":3,"id":"4","type":"function","function":{"name":"MyPlugin_InvalidArguments","arguments":"invalid_arguments_format"}}]},"finish_reason":"tool_calls"}]}
+
+data: [DONE]
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json
new file mode 100644
index 000000000000..2c499b14089f
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_multiple_function_calls_test_response.json
@@ -0,0 +1,40 @@
+{
+ "id": "response-id",
+ "object": "chat.completion",
+ "created": 1699896916,
+ "model": "gpt-3.5-turbo-0613",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {
+ "id": "1",
+ "type": "function",
+ "function": {
+ "name": "MyPlugin_Function1",
+ "arguments": "{\n\"parameter\": \"function1-value\"\n}"
+ }
+ },
+ {
+ "id": "2",
+ "type": "function",
+ "function": {
+ "name": "MyPlugin_Function2",
+ "arguments": "{\n\"parameter\": \"function2-value\"\n}"
+ }
+ }
+ ]
+ },
+ "logprobs": null,
+ "finish_reason": "tool_calls"
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 82,
+ "completion_tokens": 17,
+ "total_tokens": 99
+ }
+}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
index d8fab37e57bd..d6b5109f9b17 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json;
using Microsoft.Extensions.AI;
@@ -13,6 +14,10 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
///
internal static class ChatOptionsExtensions
{
+ internal const string KernelKey = "AutoInvokingKernel";
+ internal const string IsStreamingKey = "AutoInvokingIsStreaming";
+ internal const string ChatMessageContentKey = "AutoInvokingChatCompletionContent";
+
/// Converts a to a .
internal static PromptExecutionSettings? ToPromptExecutionSettings(this ChatOptions? options)
{
@@ -118,4 +123,20 @@ internal static class ChatOptionsExtensions
return settings;
}
+
+ ///
+ /// To enable usage of AutoFunctionInvocationFilters with ChatClients, the kernel needs to be provided in the ChatOptions
+ ///
+ /// Chat options.
+ /// Kernel to be used for auto function invocation.
+ internal static ChatOptions? AddKernel(this ChatOptions options, Kernel kernel)
+ {
+ Verify.NotNull(options);
+ Verify.NotNull(kernel);
+
+ options.AdditionalProperties ??= [];
+ options.AdditionalProperties?.TryAdd(KernelKey, kernel);
+
+ return options;
+ }
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index a17e9d02730a..360d3d53f184 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -48,7 +48,7 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient
{
/// The for the current function invocation.
- private static readonly AsyncLocal _currentContext = new();
+ private static readonly AsyncLocal s_currentContext = new();
/// The logger to use for logging information about function invocation.
private readonly ILogger _logger;
@@ -68,8 +68,8 @@ internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatC
public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
: base(innerClient)
{
- _logger = logger ?? NullLogger.Instance;
- _activitySource = innerClient.GetService();
+ this._logger = logger ?? NullLogger.Instance;
+ this._activitySource = innerClient.GetService();
}
///
@@ -80,8 +80,8 @@ public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger
///
internal static KernelFunctionInvocationContext? CurrentContext
{
- get => _currentContext.Value;
- set => _currentContext.Value = value;
+ get => s_currentContext.Value;
+ set => s_currentContext.Value = value;
}
///
@@ -97,7 +97,7 @@ internal static KernelFunctionInvocationContext? CurrentContext
///
///
/// Changing the value of this property while the client is in use might result in inconsistencies
- /// as to whether errors are retried during an in-flight request.
+ /// whether errors are retried during an in-flight request.
///
public bool RetryOnError { get; set; }
@@ -169,7 +169,7 @@ internal static KernelFunctionInvocationContext? CurrentContext
///
public int? MaximumIterationsPerRequest
{
- get => _maximumIterationsPerRequest;
+ get => this._maximumIterationsPerRequest;
set
{
if (value < 1)
@@ -177,7 +177,7 @@ public int? MaximumIterationsPerRequest
throw new ArgumentOutOfRangeException(nameof(value));
}
- _maximumIterationsPerRequest = value;
+ this._maximumIterationsPerRequest = value;
}
}
@@ -189,7 +189,7 @@ public override async Task GetResponseAsync(
// A single request into this GetResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
- using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
+ using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
// Copy the original messages in order to avoid enumerating the original messages multiple times.
// The IEnumerable can represent an arbitrary amount of work.
@@ -217,7 +217,7 @@ public override async Task GetResponseAsync(
// Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
bool requiresFunctionInvocation =
options?.Tools is { Count: > 0 } &&
- (!MaximumIterationsPerRequest.HasValue || iteration < MaximumIterationsPerRequest.GetValueOrDefault()) &&
+ (!this.MaximumIterationsPerRequest.HasValue || iteration < this.MaximumIterationsPerRequest.GetValueOrDefault()) &&
CopyFunctionCalls(response.Messages, ref functionCallContents);
// In a common case where we make a request and there's no function calling work required,
@@ -252,11 +252,17 @@ public override async Task GetResponseAsync(
// Prepare the history for the next iteration.
FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
+ // Prepare the options for the next auto function invocation iteration.
+ UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: false);
+
// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
+ // Clear the auto function invocation options.
+ ClearOptionsForAutoFunctionInvocation(ref options!);
+
if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId))
{
// Terminate
@@ -279,7 +285,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
// A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
- using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
+ using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
// Copy the original messages in order to avoid enumerating the original messages multiple times.
// The IEnumerable can represent an arbitrary amount of work.
@@ -315,7 +321,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
// If there are no tools to call, or for any other reason we should stop, return the response.
if (functionCallContents is not { Count: > 0 } ||
options?.Tools is not { Count: > 0 } ||
- (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations))
+ (this.MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations))
{
break;
}
@@ -327,12 +333,18 @@ public override async IAsyncEnumerable GetStreamingResponseA
// Prepare the history for the next iteration.
FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
- // Process all of the functions, adding their results into the history.
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
+ // Prepare the options for the next auto function invocation iteration.
+ UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: true);
+
+ // Process all the functions, adding their results into the history.
+ var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
+ // Clear the auto function invocation options.
+ ClearOptionsForAutoFunctionInvocation(ref options!);
+
// Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
- // includes all activities, including generated function results.
+ // include all activities, including generated function results.
string toolResponseId = Guid.NewGuid().ToString("N");
foreach (var message in modeAndMessages.MessagesAdded)
{
@@ -454,6 +466,37 @@ private static bool CopyFunctionCalls(
return any;
}
+ private static void UpdateOptionsForAutoFunctionInvocation(ref ChatOptions options, ChatMessageContent content, bool isStreaming)
+ {
+ if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false)
+ {
+ throw new KernelException($"The reserved key name '{ChatOptionsExtensions.IsStreamingKey}' is already specified in the options. Avoid using this key name.");
+ }
+
+ if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false)
+ {
+ throw new KernelException($"The reserved key name '{ChatOptionsExtensions.ChatMessageContentKey}' is already specified in the options. Avoid using this key name.");
+ }
+
+ options.AdditionalProperties ??= [];
+
+ options.AdditionalProperties[ChatOptionsExtensions.IsStreamingKey] = isStreaming;
+ options.AdditionalProperties[ChatOptionsExtensions.ChatMessageContentKey] = content;
+ }
+
+ private static void ClearOptionsForAutoFunctionInvocation(ref ChatOptions options)
+ {
+ if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.IsStreamingKey) ?? false)
+ {
+ options.AdditionalProperties.Remove(ChatOptionsExtensions.IsStreamingKey);
+ }
+
+ if (options.AdditionalProperties?.ContainsKey(ChatOptionsExtensions.ChatMessageContentKey) ?? false)
+ {
+ options.AdditionalProperties.Remove(ChatOptionsExtensions.ChatMessageContentKey);
+ }
+ }
+
/// Updates for the response.
/// true if the function calling loop should terminate; otherwise, false.
private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId)
@@ -517,11 +560,11 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
// Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel.
if (functionCallContents.Count == 1)
{
- FunctionInvocationResult result = await ProcessFunctionCallAsync(
+ FunctionInvocationResult result = await this.ProcessFunctionCallAsync(
messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
- IList added = CreateResponseMessages([result]);
- ThrowIfNoFunctionResultsAdded(added);
+ IList added = this.CreateResponseMessages([result]);
+ this.ThrowIfNoFunctionResultsAdded(added);
messages.AddRange(added);
return (result.ContinueMode, added);
@@ -530,12 +573,12 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
{
FunctionInvocationResult[] results;
- if (AllowConcurrentInvocation)
+ if (this.AllowConcurrentInvocation)
{
// Schedule the invocation of every function.
results = await Task.WhenAll(
from i in Enumerable.Range(0, functionCallContents.Count)
- select Task.Run(() => ProcessFunctionCallAsync(
+ select Task.Run(() => this.ProcessFunctionCallAsync(
messages, options, functionCallContents,
iteration, i, cancellationToken))).ConfigureAwait(false);
}
@@ -545,7 +588,7 @@ select Task.Run(() => ProcessFunctionCallAsync(
results = new FunctionInvocationResult[functionCallContents.Count];
for (int i = 0; i < results.Length; i++)
{
- results[i] = await ProcessFunctionCallAsync(
+ results[i] = await this.ProcessFunctionCallAsync(
messages, options, functionCallContents,
iteration, i, cancellationToken).ConfigureAwait(false);
}
@@ -553,8 +596,8 @@ select Task.Run(() => ProcessFunctionCallAsync(
ContinueMode continueMode = ContinueMode.Continue;
- IList added = CreateResponseMessages(results);
- ThrowIfNoFunctionResultsAdded(added);
+ IList added = this.CreateResponseMessages(results);
+ this.ThrowIfNoFunctionResultsAdded(added);
messages.AddRange(added);
foreach (FunctionInvocationResult fir in results)
@@ -576,7 +619,7 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
{
if (messages is null || messages.Count == 0)
{
- throw new InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages.");
+ throw new InvalidOperationException($"{this.GetType().Name}.{nameof(this.CreateResponseMessages)} returned null or an empty collection of messages.");
}
}
@@ -595,13 +638,14 @@ private async Task ProcessFunctionCallAsync(
var callContent = callContents[functionCallIndex];
// Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
- AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name);
+ AIFunction? function = options.Tools!.OfType().FirstOrDefault(
+ t => t.Name == callContent.Name);
if (function is null)
{
return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
}
- KernelFunctionInvocationContext context = new()
+ var context = new AutoFunctionInvocationContext(new()
{
Messages = messages,
Options = options,
@@ -609,18 +653,17 @@ private async Task ProcessFunctionCallAsync(
AIFunction = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
- FunctionCount = callContents.Count,
- };
+ FunctionCount = callContents.Count
+ });
object? result;
try
{
- result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false);
+ result = await this.InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false);
}
catch (Exception e) when (!cancellationToken.IsCancellationRequested)
{
- return new(
- RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer.
+ return new(this.RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer.
FunctionInvocationStatus.Exception,
callContent,
result: null,
@@ -681,7 +724,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
_ => "Error: Unknown error.",
};
- if (IncludeDetailedErrors && result.Exception is not null)
+ if (this.IncludeDetailedErrors && result.Exception is not null)
{
message = $"{message} Exception: {result.Exception.Message}";
}
@@ -744,19 +787,19 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
{
Verify.NotNull(invocationContext);
- using Activity? activity = _activitySource?.StartActivity(invocationContext.AIFunction.Name);
+ using Activity? activity = this._activitySource?.StartActivity(invocationContext.AIFunction.Name);
long startingTimestamp = 0;
- if (_logger.IsEnabled(LogLevel.Debug))
+ if (this._logger.IsEnabled(LogLevel.Debug))
{
startingTimestamp = Stopwatch.GetTimestamp();
- if (_logger.IsEnabled(LogLevel.Trace))
+ if (this._logger.IsEnabled(LogLevel.Trace))
{
- LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions));
+ this.LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions));
}
else
{
- LogInvoking(invocationContext.AIFunction.Name);
+ this.LogInvoking(invocationContext.AIFunction.Name);
}
}
@@ -788,6 +831,8 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
var result = await invocationContext.AIFunction.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
context.Result = new FunctionResult(context.Function, result);
}).ConfigureAwait(false);
+
+ invocationContext = autoFunctionInvocationContext;
}
}
catch (Exception e)
@@ -800,28 +845,28 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (e is OperationCanceledException)
{
- LogInvocationCanceled(invocationContext.AIFunction.Name);
+ this.LogInvocationCanceled(invocationContext.AIFunction.Name);
}
else
{
- LogInvocationFailed(invocationContext.AIFunction.Name, e);
+ this.LogInvocationFailed(invocationContext.AIFunction.Name, e);
}
throw;
}
finally
{
- if (_logger.IsEnabled(LogLevel.Debug))
+ if (this._logger.IsEnabled(LogLevel.Debug))
{
TimeSpan elapsed = GetElapsedTime(startingTimestamp);
- if (result is not null && _logger.IsEnabled(LogLevel.Trace))
+ if (result is not null && this._logger.IsEnabled(LogLevel.Trace))
{
- LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions));
+ this.LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions));
}
else
{
- LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed);
+ this.LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed);
}
}
}
@@ -881,11 +926,11 @@ public sealed class FunctionInvocationResult
{
internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationStatus status, Microsoft.Extensions.AI.FunctionCallContent callContent, object? result, Exception? exception)
{
- ContinueMode = continueMode;
- Status = status;
- CallContent = callContent;
- Result = result;
- Exception = exception;
+ this.ContinueMode = continueMode;
+ this.Status = status;
+ this.CallContent = callContent;
+ this.Result = result;
+ this.Exception = exception;
}
/// Gets status about how the function invocation completed.
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs
index 98bb09be6f85..f9d4a5db151b 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using Microsoft.Extensions.AI;
+using Microsoft.SemanticKernel.ChatCompletion;
namespace Microsoft.SemanticKernel;
@@ -149,7 +150,8 @@ internal static class PromptExecutionSettingsExtensions
options.Tools = functions.Select(f => f.AsAIFunction(kernel)).Cast().ToList();
}
- return options;
+ // Enables usage of AutoFunctionInvocationFilters
+ return options.AddKernel(kernel!);
// Be a little lenient on the types of the values used in the extension data,
// e.g. allow doubles even when requesting floats.
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index 68e06140cc9b..9dd8b094e50c 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -25,17 +25,20 @@ public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContex
Verify.NotNull(innerContext.Options);
Verify.NotNull(innerContext.Options.AdditionalProperties);
- innerContext.Options.AdditionalProperties.TryGetValue("Kernel", out var kernel);
+ innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel);
Verify.NotNull(kernel);
- innerContext.Options.AdditionalProperties.TryGetValue("ChatMessageContent", out var chatMessageContent);
+ innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent);
Verify.NotNull(chatMessageContent);
+ innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming);
+ this.IsStreaming = isStreaming;
+
this._innerContext = innerContext;
this.ChatHistory = new ChatMessageHistory(innerContext.Messages);
this.ChatMessageContent = chatMessageContent;
this.Kernel = kernel;
- this.Result = new FunctionResult(this._kernelFunction!) { Culture = kernel.Culture };
+ this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture };
}
///
diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
index a0a425aca1ec..a3af72a44d08 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
@@ -532,7 +532,7 @@ public KernelAIFunction(KernelFunction kernelFunction, Kernel? kernel)
this.JsonSchema = BuildFunctionSchema(kernelFunction);
}
-
+ internal string KernelFunctionName => this._kernelFunction.Name;
public override string Name { get; }
public override JsonElement JsonSchema { get; }
public override string Description => this._kernelFunction.Description;
From 704c80b07babd16dae14886ed28590adb60ee7d2 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Tue, 8 Apr 2025 17:35:35 +0100
Subject: [PATCH 07/26] AutoContext losing result
---
.../KernelFunctionInvokingChatClient.cs | 44 ++++----
.../AutoFunctionInvocationContext.cs | 105 +++++++++++-------
2 files changed, 88 insertions(+), 61 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 360d3d53f184..66b6e46e84ec 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -645,16 +645,20 @@ private async Task ProcessFunctionCallAsync(
return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
}
- var context = new AutoFunctionInvocationContext(new()
+ if (callContent.Arguments is not null)
+ {
+ callContent.Arguments = new KernelArguments(callContent.Arguments);
+ }
+
+ var context = new AutoFunctionInvocationContext(options)
{
Messages = messages,
- Options = options,
CallContent = callContent,
AIFunction = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count
- });
+ };
object? result;
try
@@ -810,29 +814,23 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (invocationContext is AutoFunctionInvocationContext autoFunctionInvocationContext)
{
autoFunctionInvocationContext = await this.OnAutoFunctionInvocationAsync(
- autoFunctionInvocationContext,
- async (context) =>
- {
- // Check if filter requested termination
- if (context.Terminate)
- {
- return;
- }
-
- // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
- // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
- // as the called function could in turn telling the model about itself as a possible candidate for invocation.
- KernelArguments? arguments = null;
- if (context.Arguments is not null)
+ autoFunctionInvocationContext,
+ async (context) =>
{
- arguments = new(context.Arguments);
- }
-
- var result = await invocationContext.AIFunction.InvokeAsync(invocationContext.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
- context.Result = new FunctionResult(context.Function, result);
+ // Check if filter requested termination
+ if (context.Terminate)
+ {
+ return;
+ }
+
+ // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
+ // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
+ // as the called function could in turn tell the model about itself as a possible candidate for invocation.
+ result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false);
+ context.Result = new FunctionResult(context.Function, result);
}).ConfigureAwait(false);
- invocationContext = autoFunctionInvocationContext;
+ result = autoFunctionInvocationContext.Result.GetValue();
}
}
catch (Exception e)
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index 9dd8b094e50c..a431ed868381 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
+using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
@@ -14,30 +15,29 @@ namespace Microsoft.SemanticKernel;
///
public class AutoFunctionInvocationContext : KernelFunctionInvocationContext
{
- private readonly KernelFunction? _kernelFunction;
- private readonly KernelFunctionInvocationContext? _innerContext;
+ private ChatHistory? _chatHistory;
+
///
/// Initializes a new instance of the class from an existing .
///
- public AutoFunctionInvocationContext(KernelFunctionInvocationContext innerContext)
+ public AutoFunctionInvocationContext(ChatOptions options)
{
- Verify.NotNull(innerContext);
- Verify.NotNull(innerContext.Options);
- Verify.NotNull(innerContext.Options.AdditionalProperties);
+ this.Options = options;
+
+ // To create a AutoFunctionInvocationContext from a KernelFunctionInvocationContext,
+ // the ChatOptions must be provided with AdditionalProperties.
+ Verify.NotNull(options.AdditionalProperties);
- innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel);
+ options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel);
Verify.NotNull(kernel);
- innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent);
+ options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent);
Verify.NotNull(chatMessageContent);
- innerContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming);
- this.IsStreaming = isStreaming;
+ options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming);
+ Verify.NotNull(isStreaming);
+ this.IsStreaming = isStreaming.Value;
- this._innerContext = innerContext;
- this.ChatHistory = new ChatMessageHistory(innerContext.Messages);
- this.ChatMessageContent = chatMessageContent;
- this.Kernel = kernel;
this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture };
}
@@ -62,11 +62,18 @@ public AutoFunctionInvocationContext(
Verify.NotNull(chatHistory);
Verify.NotNull(chatMessageContent);
- this.Kernel = kernel;
- this._kernelFunction = function;
+ this.Options = new()
+ {
+ AdditionalProperties = new()
+ {
+ [ChatOptionsExtensions.ChatMessageContentKey] = chatMessageContent,
+ [ChatOptionsExtensions.KernelKey] = kernel
+ }
+ };
+
+ this.Messages = chatHistory.ToChatMessageList();
+ this.AIFunction = function.AsAIFunction();
this.Result = result;
- this.ChatHistory = chatHistory;
- this.ChatMessageContent = chatMessageContent;
}
///
@@ -83,27 +90,48 @@ public AutoFunctionInvocationContext(
///
/// Gets the arguments associated with the operation.
///
- public KernelArguments? Arguments { get; init; }
+ public KernelArguments? Arguments
+ {
+ get => this.CallContent.Arguments is KernelArguments kernelArguments ? kernelArguments : null;
+ init => this.CallContent.Arguments = value;
+ }
///
/// Request sequence index of automatic function invocation process. Starts from 0.
///
- public int RequestSequenceIndex { get; init; }
+ public int RequestSequenceIndex
+ {
+ get => this.Iteration;
+ init => this.Iteration = value;
+ }
///
/// Function sequence index. Starts from 0.
///
- public int FunctionSequenceIndex { get; init; }
+ public int FunctionSequenceIndex
+ {
+ get => this.FunctionCallIndex;
+ init => this.FunctionCallIndex = value;
+ }
///
/// The ID of the tool call.
///
- public string? ToolCallId { get; init; }
+ public string? ToolCallId
+ {
+ get => this.CallContent.CallId;
+ init
+ {
+ Verify.NotNull(value);
+ // ToolCallId
+ this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value, this.CallContent.Name, this.CallContent.Arguments);
+ }
+ }
///
/// The chat message content associated with automatic function invocation.
///
- public ChatMessageContent ChatMessageContent { get; }
+ public ChatMessageContent ChatMessageContent => (this.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!;
///
/// The execution settings associated with the operation.
@@ -114,20 +142,27 @@ public AutoFunctionInvocationContext(
///
/// Gets the associated with automatic function invocation.
///
- public ChatHistory ChatHistory { get; }
+ public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this.Messages);
///
/// Gets the with which this filter is associated.
///
- public KernelFunction Function
- {
- get => this._innerContext?.AIFunction.AsKernelFunction() ?? this._kernelFunction!;
- }
+ public KernelFunction Function => this.AIFunction.AsKernelFunction();
///
/// Gets the containing services, plugins, and other state for use throughout the operation.
///
- public Kernel Kernel { get; }
+ public Kernel Kernel
+ {
+ get
+ {
+ Kernel? kernel = null;
+ this.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel);
+
+ // To avoid exception from properties, when attempting to retrieve a kernel from a non-ready context, it will give a null.
+ return kernel!;
+ }
+ }
///
/// Gets or sets the result of the function's invocation.
@@ -135,20 +170,19 @@ public KernelFunction Function
public FunctionResult Result { get; set; }
///
- /// Mutable chat message as chat history.
+ /// Mutable IEnumerable of chat message as chat history.
///
- internal class ChatMessageHistory : ChatHistory, IEnumerable
+ private class ChatMessageHistory : ChatHistory, IEnumerable
{
private readonly List _messages;
- public ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory())
+ internal ChatMessageHistory(IEnumerable messages) : base(messages.ToChatHistory())
{
this._messages = new List(messages);
}
public override void Add(ChatMessageContent item)
{
- item.GetHashCode();
base.Add(item);
this._messages.Add(item.ToChatMessage());
}
@@ -159,11 +193,6 @@ public override void Clear()
this._messages.Clear();
}
- public override bool Contains(ChatMessageContent item)
- {
- return base.Contains(item);
- }
-
public override bool Remove(ChatMessageContent item)
{
var index = base.IndexOf(item);
From e4ae4948bf1b435500c5e54f57f471230ad7c2db Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Wed, 9 Apr 2025 13:59:46 +0100
Subject: [PATCH 08/26] FilterCanOverrideArguments
---
...FunctionInvocationFilterChatClientTests.cs | 20 ++++++++++++++++++-
1 file changed, 19 insertions(+), 1 deletion(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index d28c67346c3f..908acf9356a0 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -7,6 +7,7 @@
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
+using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
@@ -316,7 +317,12 @@ public async Task FilterCanOverrideArgumentsAsync()
}));
// Assert
- Assert.Equal("NewValue", result.ToString());
+ var chatResponse = Assert.IsType(result.GetValue());
+ Assert.NotNull(chatResponse);
+
+ var lastFunctionResult = GetLastFunctionResultFromChatResponse(chatResponse);
+ Assert.NotNull(lastFunctionResult);
+ Assert.Equal("NewValue", lastFunctionResult.ToString());
}
[Fact]
@@ -728,6 +734,18 @@ public void Dispose()
#region private
+ private static object? GetLastFunctionResultFromChatResponse(ChatResponse chatResponse)
+ {
+ Assert.NotEmpty(chatResponse.Messages);
+ var chatMessage = chatResponse.Messages[^1];
+
+ Assert.NotEmpty(chatMessage.Contents);
+ Assert.Contains(chatMessage.Contents, c => c is Microsoft.Extensions.AI.FunctionResultContent);
+
+ var resultContent = (Microsoft.Extensions.AI.FunctionResultContent)chatMessage.Contents.Last(c => c is Microsoft.Extensions.AI.FunctionResultContent);
+ return resultContent.Result;
+ }
+
#pragma warning disable CA2000 // Dispose objects before losing scope
private static List GetFunctionCallingResponses()
{
From 7a5fc08158a2e9cbae81faf2fb34d04a2504bcf9 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Wed, 9 Apr 2025 14:51:23 +0100
Subject: [PATCH 09/26] Fix failing UT + Added PromptExecutionSettings to be
propagated into AutoInvocationContext
---
.../Connectors.OpenAI.UnitTests.csproj | 3 ++
...FunctionInvocationFilterChatClientTests.cs | 28 +++++++++----------
..._multiple_function_calls_test_response.txt | 5 ++++
.../AI/ChatClient/ChatClientExtensions.cs | 26 +++++++++++++++++
.../AI/ChatClient/ChatOptionsExtensions.cs | 1 +
.../KernelFunctionInvokingChatClient.cs | 2 +-
.../AutoFunctionInvocationContext.cs | 13 ++++++++-
.../Functions/KernelFunctionFromPrompt.cs | 2 +-
8 files changed, 63 insertions(+), 17 deletions(-)
create mode 100644 dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
index 33557430e615..04d35b9e6561 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Connectors.OpenAI.UnitTests.csproj
@@ -102,6 +102,9 @@
Always
+
+ Always
+
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index 908acf9356a0..3031ead057b4 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -67,7 +67,7 @@ public async Task FiltersAreExecutedCorrectlyAsync()
// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
// Assert
@@ -145,7 +145,7 @@ public async Task FiltersAreExecutedCorrectlyOnStreamingAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
- var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+ var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
// Act
await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
@@ -350,7 +350,7 @@ public async Task FilterCanHandleExceptionAsync()
var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
- var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+ var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
var chatHistory = new ChatHistory();
chatHistory.AddSystemMessage("System message");
@@ -391,7 +391,7 @@ public async Task FilterCanHandleExceptionOnStreamingAsync()
var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
var chatHistory = new ChatHistory();
- var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+ var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
// Act
await foreach (var item in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel))
@@ -437,7 +437,7 @@ public async Task FiltersCanSkipFunctionExecutionAsync()
// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
// Assert
@@ -469,9 +469,9 @@ public async Task PreFilterCanTerminateOperationAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
// Act
- await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
// Assert
@@ -501,7 +501,7 @@ public async Task PreFilterCanTerminateOperationOnStreamingAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
- var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+ var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
// Act
await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
@@ -542,7 +542,7 @@ public async Task PostFilterCanTerminateOperationAsync()
// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
// Assert
@@ -586,7 +586,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
- var executionSettings = new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions };
+ var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
List streamingContent = [];
@@ -683,9 +683,9 @@ public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterCo
return Task.CompletedTask;
});
- var expectedExecutionSettings = new OpenAIPromptExecutionSettings
+ var expectedExecutionSettings = new PromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
};
// Act
@@ -712,9 +712,9 @@ public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingT
return Task.CompletedTask;
});
- var expectedExecutionSettings = new OpenAIPromptExecutionSettings
+ var expectedExecutionSettings = new PromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
};
// Act
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt
new file mode 100644
index 000000000000..c113e3fa97ca
--- /dev/null
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/TestData/filters_chatclient_streaming_multiple_function_calls_test_response.txt
@@ -0,0 +1,5 @@
+data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":0,"id":"1","type":"function","function":{"name":"MyPlugin_Function1","arguments":"{\n\"parameter\": \"function1-value\"\n}"}}]},"finish_reason":"tool_calls"}]}
+
+data: {"id":"response-id","object":"chat.completion.chunk","created":1704212243,"model":"gpt-4","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"Test chat streaming response","tool_calls":[{"index":1,"id":"2","type":"function","function":{"name":"MyPlugin_Function2","arguments":"{\n\"parameter\": \"function2-value\"\n}"}}]},"finish_reason":"tool_calls"}]}
+
+data: [DONE]
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
index c480772750de..f3aa9c4ad069 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
+using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;
@@ -30,6 +31,9 @@ internal static Task GetResponseAsync(
{
var chatOptions = executionSettings?.ToChatOptions(kernel);
+ // Passing by reference to be used by AutoFunctionInvocationFilters
+ chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings;
+
// Try to parse the text as a chat history
if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt))
{
@@ -40,6 +44,28 @@ internal static Task GetResponseAsync(
return chatClient.GetResponseAsync(prompt, chatOptions, cancellationToken);
}
+ /// Get ChatClient streaming response for the prompt, settings and kernel.
+ /// Target chat client service.
+ /// The standardized prompt input.
+ /// The AI execution settings (optional).
+ /// The containing services, plugins, and other state for use throughout the operation.
+ /// The to monitor for cancellation requests. The default is .
+ /// Streaming list of different completion streaming string updates generated by the remote model
+ internal static IAsyncEnumerable GetStreamingResponseAsync(
+ this IChatClient chatClient,
+ string prompt,
+ PromptExecutionSettings? executionSettings = null,
+ Kernel? kernel = null,
+ CancellationToken cancellationToken = default)
+ {
+ var chatOptions = executionSettings?.ToChatOptions(kernel);
+
+ // Passing by reference to be used by AutoFunctionInvocationFilters
+ chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings;
+
+ return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken);
+ }
+
/// Creates an for the specified .
/// The chat client to be represented as a chat completion service.
/// An optional that can be used to resolve services to use in the instance.
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
index d6b5109f9b17..de2c6831b93d 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
@@ -17,6 +17,7 @@ internal static class ChatOptionsExtensions
internal const string KernelKey = "AutoInvokingKernel";
internal const string IsStreamingKey = "AutoInvokingIsStreaming";
internal const string ChatMessageContentKey = "AutoInvokingChatCompletionContent";
+ internal const string PromptExecutionSettingsKey = "AutoInvokingPromptExecutionSettings";
/// Converts a to a .
internal static PromptExecutionSettings? ToPromptExecutionSettings(this ChatOptions? options)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 66b6e46e84ec..3f54113ab997 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -657,7 +657,7 @@ private async Task ProcessFunctionCallAsync(
AIFunction = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
- FunctionCount = callContents.Count
+ FunctionCount = callContents.Count,
};
object? result;
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index a431ed868381..e642c1a55d8f 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -34,6 +34,9 @@ public AutoFunctionInvocationContext(ChatOptions options)
options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent);
Verify.NotNull(chatMessageContent);
+ options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings);
+ this.ExecutionSettings = executionSettings;
+
options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming);
Verify.NotNull(isStreaming);
this.IsStreaming = isStreaming.Value;
@@ -137,7 +140,15 @@ public string? ToolCallId
/// The execution settings associated with the operation.
///
[Experimental("SKEXP0001")]
- public PromptExecutionSettings? ExecutionSettings { get; init; }
+ public PromptExecutionSettings? ExecutionSettings
+ {
+ get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings;
+ init
+ {
+ this.Options.AdditionalProperties ??= [];
+ this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value;
+ }
+ }
///
/// Gets the associated with automatic function invocation.
diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
index 0fc3984f8d80..9db8f8e1ce46 100644
--- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
+++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
@@ -286,7 +286,7 @@ protected override async IAsyncEnumerable InvokeStreamingCoreAsync
Date: Wed, 9 Apr 2025 18:31:20 +0100
Subject: [PATCH 10/26] Resolving UT for FilterCanHandleException
---
...FunctionInvocationFilterChatClientTests.cs | 30 +++++++++++--------
.../ChatCompletion/ChatHistoryExtensions.cs | 5 +++-
.../AI/PromptExecutionSettingsExtensions.cs | 5 ++--
3 files changed, 24 insertions(+), 16 deletions(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index 3031ead057b4..0e0a7c1d4246 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -6,6 +6,7 @@
using System.Linq;
using System.Net;
using System.Net.Http;
+using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
@@ -185,10 +186,7 @@ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
builder.Plugins.Add(plugin);
- builder.Services.AddSingleton((serviceProvider) =>
- {
- return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
- });
+ builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient);
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
@@ -202,9 +200,9 @@ public async Task DifferentWaysOfAddingFiltersWorkCorrectlyAsync()
// Case #2 - Add filter to kernel
kernel.AutoFunctionInvocationFilters.Add(filter2);
- var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings
{
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
+ FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
// Assert
@@ -348,22 +346,28 @@ public async Task FilterCanHandleExceptionAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
- var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+ var chatClient = kernel.GetRequiredService();
- var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+ var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
var chatHistory = new ChatHistory();
chatHistory.AddSystemMessage("System message");
// Act
- var result = await chatCompletion.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);
+ var messageList = chatHistory.ToChatMessageList();
+ var options = executionSettings.ToChatOptions(kernel);
+ var resultMessages = await chatClient.GetResponseAsync(messageList, options, CancellationToken.None);
- var firstFunctionResult = chatHistory[^2].Content;
- var secondFunctionResult = chatHistory[^1].Content;
+ var firstToolMessage = resultMessages.Messages.First(m => m.Role == ChatRole.Tool);
+ Assert.NotNull(firstToolMessage);
+ var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent;
+ var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent;
// Assert
- Assert.Equal("Result from filter", firstFunctionResult);
- Assert.Equal("Result from Function2", secondFunctionResult);
+ Assert.NotNull(firstFunctionResult);
+ Assert.NotNull(secondFunctionResult);
+ Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString());
+ Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString());
}
[Fact]
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
index a238e77417da..f027a9b7a31c 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
@@ -82,6 +82,9 @@ public static async Task ReduceAsync(this ChatHistory chatHistory,
return chatHistory;
}
- internal static List ToChatMessageList(this ChatHistory chatHistory)
+ /// Converts a to a list.
+ /// The chat history to convert.
+ /// A list of objects.
+ public static List ToChatMessageList(this ChatHistory chatHistory)
=> chatHistory.Select(m => m.ToChatMessage()).ToList();
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs
index f9d4a5db151b..74fe27c2e841 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/PromptExecutionSettingsExtensions.cs
@@ -12,10 +12,11 @@
namespace Microsoft.SemanticKernel;
-internal static class PromptExecutionSettingsExtensions
+/// Extensions methods for .
+public static class PromptExecutionSettingsExtensions
{
/// Converts a pair of and to a .
- internal static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? kernel)
+ public static ChatOptions? ToChatOptions(this PromptExecutionSettings? settings, Kernel? kernel)
{
if (settings is null)
{
From 85c8cc44ed3688e1643bd58f77c94f44fd4f6c2e Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Wed, 9 Apr 2025 18:58:34 +0100
Subject: [PATCH 11/26] Adjust for HandleExceptionOnStreaming
---
...FunctionInvocationFilterChatClientTests.cs | 38 +++++++++++--------
.../ChatCompletion/ChatHistoryExtensions.cs | 2 +-
2 files changed, 23 insertions(+), 17 deletions(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index 0e0a7c1d4246..3c42a004eb32 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -349,21 +349,18 @@ public async Task FilterCanHandleExceptionAsync()
var chatClient = kernel.GetRequiredService();
var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
-
- var chatHistory = new ChatHistory();
- chatHistory.AddSystemMessage("System message");
+ var options = executionSettings.ToChatOptions(kernel);
+ List messageList = [new(ChatRole.System, "System message")];
// Act
- var messageList = chatHistory.ToChatMessageList();
- var options = executionSettings.ToChatOptions(kernel);
var resultMessages = await chatClient.GetResponseAsync(messageList, options, CancellationToken.None);
+ // Assert
var firstToolMessage = resultMessages.Messages.First(m => m.Role == ChatRole.Tool);
Assert.NotNull(firstToolMessage);
var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent;
var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent;
- // Assert
Assert.NotNull(firstFunctionResult);
Assert.NotNull(secondFunctionResult);
Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString());
@@ -392,21 +389,30 @@ public async Task FilterCanHandleExceptionOnStreamingAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
- var chatCompletion = new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
+ var chatClient = kernel.GetRequiredService();
- var chatHistory = new ChatHistory();
- var executionSettings = new OpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+ var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+ var options = executionSettings.ToChatOptions(kernel);
+ List messageList = [];
// Act
- await foreach (var item in chatCompletion.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel))
- { }
-
- var firstFunctionResult = chatHistory[^2].Content;
- var secondFunctionResult = chatHistory[^1].Content;
+ List streamingContent = [];
+ await foreach (var update in chatClient.GetStreamingResponseAsync(messageList, options, CancellationToken.None))
+ {
+ streamingContent.Add(update);
+ }
+ var chatResponse = streamingContent.ToChatResponse();
// Assert
- Assert.Equal("Result from filter", firstFunctionResult);
- Assert.Equal("Result from Function2", secondFunctionResult);
+ var firstToolMessage = chatResponse.Messages.First(m => m.Role == ChatRole.Tool);
+ Assert.NotNull(firstToolMessage);
+ var firstFunctionResult = firstToolMessage.Contents[^2] as Microsoft.Extensions.AI.FunctionResultContent;
+ var secondFunctionResult = firstToolMessage.Contents[^1] as Microsoft.Extensions.AI.FunctionResultContent;
+
+ Assert.NotNull(firstFunctionResult);
+ Assert.NotNull(secondFunctionResult);
+ Assert.Equal("Result from filter", firstFunctionResult.Result!.ToString());
+ Assert.Equal("Result from Function2", secondFunctionResult.Result!.ToString());
}
[Fact]
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
index f027a9b7a31c..9702bc5da22e 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
@@ -85,6 +85,6 @@ public static async Task ReduceAsync(this ChatHistory chatHistory,
/// Converts a to a list.
/// The chat history to convert.
/// A list of objects.
- public static List ToChatMessageList(this ChatHistory chatHistory)
+ internal static List ToChatMessageList(this ChatHistory chatHistory)
=> chatHistory.Select(m => m.ToChatMessage()).ToList();
}
From ead8061c86aa1686193e2f00f5c486106d68a9c2 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Wed, 9 Apr 2025 20:48:03 +0100
Subject: [PATCH 12/26] Added all behavior except Skipping
---
...FunctionInvocationFilterChatClientTests.cs | 60 ++++++++----------
.../FunctionCalling/FunctionCallsProcessor.cs | 6 +-
.../KernelFunctionInvokingChatClient.cs | 61 ++++++++++++-------
.../AutoFunctionInvocationContext.cs | 4 --
4 files changed, 67 insertions(+), 64 deletions(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index 3c42a004eb32..1abb85a74b52 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -248,10 +248,7 @@ public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming)
builder.Plugins.Add(plugin);
- builder.Services.AddSingleton((serviceProvider) =>
- {
- return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
- });
+ builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient);
builder.Services.AddSingleton(filter1);
builder.Services.AddSingleton(filter2);
@@ -259,24 +256,21 @@ public async Task MultipleFiltersAreExecutedInOrderAsync(bool isStreaming)
var kernel = builder.Build();
- var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
- {
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
- });
+ var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
// Act
if (isStreaming)
{
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
- await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", arguments))
+ await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(settings)))
{ }
}
else
{
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
- await kernel.InvokePromptAsync("Test prompt", arguments);
+ await kernel.InvokePromptAsync("Test prompt", new(settings));
}
// Assert
@@ -439,13 +433,13 @@ public async Task FiltersCanSkipFunctionExecutionAsync()
filterInvocations++;
});
- using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_multiple_function_calls_test_response.json")) };
+ using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/filters_chatclient_multiple_function_calls_test_response.json")) };
using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(File.ReadAllText("TestData/chat_completion_test_response.json")) };
this._messageHandlerStub.ResponsesToReturn = [response1, response2];
// Act
- var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ var result = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings
{
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
@@ -550,7 +544,7 @@ public async Task PostFilterCanTerminateOperationAsync()
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
// Act
- var result = await kernel.InvokePromptAsync("Test prompt", new(new OpenAIPromptExecutionSettings
+ var functionResult = await kernel.InvokePromptAsync("Test prompt", new(new PromptExecutionSettings
{
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
}));
@@ -562,11 +556,12 @@ public async Task PostFilterCanTerminateOperationAsync()
Assert.Equal([0], functionSequenceNumbers);
// Results of function invoked before termination should be returned
- var lastMessageContent = result.GetValue();
- Assert.NotNull(lastMessageContent);
+ var chatResponse = functionResult.GetValue();
+ Assert.NotNull(chatResponse);
- Assert.Equal("function1-value", lastMessageContent.Content);
- Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
+ var result = GetLastFunctionResultFromChatResponse(chatResponse);
+ Assert.NotNull(result);
+ Assert.Equal("function1-value", result.ToString());
}
[Fact]
@@ -598,12 +593,12 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
var executionSettings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
- List streamingContent = [];
+ List streamingContent = [];
// Act
- await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
+ await foreach (var update in kernel.InvokePromptStreamingAsync("Test prompt", new(executionSettings)))
{
- streamingContent.Add(item);
+ streamingContent.Add(update);
}
// Assert
@@ -613,13 +608,14 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
Assert.Equal([0], functionSequenceNumbers);
// Results of function invoked before termination should be returned
- Assert.Equal(3, streamingContent.Count);
+ Assert.Equal(4, streamingContent.Count);
- var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent;
- Assert.NotNull(lastMessageContent);
+ var chatResponse = streamingContent.ToChatResponse();
+ Assert.NotNull(chatResponse);
- Assert.Equal("function1-value", lastMessageContent.Content);
- Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
+ var result = GetLastFunctionResultFromChatResponse(chatResponse);
+ Assert.NotNull(result);
+ Assert.Equal("function1-value", result.ToString());
}
[Theory]
@@ -645,32 +641,26 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
builder.Plugins.Add(plugin);
- builder.Services.AddSingleton((serviceProvider) =>
- {
- return new OpenAIChatCompletionService("model-id", "test-api-key", "organization-id", this._httpClient);
- });
+ builder.Services.AddOpenAIChatClient("model-id", "test-api-key", "organization-id", httpClient: this._httpClient);
builder.Services.AddSingleton(filter);
var kernel = builder.Build();
- var arguments = new KernelArguments(new OpenAIPromptExecutionSettings
- {
- ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
- });
+ var settings = new PromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
// Act
if (isStreaming)
{
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();
- await kernel.InvokePromptStreamingAsync("Test prompt", arguments).ToListAsync();
+ await kernel.InvokePromptStreamingAsync("Test prompt", new(settings)).ToListAsync();
}
else
{
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();
- await kernel.InvokePromptAsync("Test prompt", arguments);
+ await kernel.InvokePromptAsync("Test prompt", new(settings));
}
// Assert
diff --git a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs
index 97cb426c307d..88f3da9d6a53 100644
--- a/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs
+++ b/dotnet/src/InternalUtilities/connectors/AI/FunctionCalling/FunctionCallsProcessor.cs
@@ -211,7 +211,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
{
bool terminationRequested = false;
- // Wait for all of the function invocations to complete, then add the results to the chat, but stop when we hit a
+ // Wait for all the function invocations to complete, then add the results to the chat, but stop when we hit a
// function for which termination was requested.
FunctionResultContext[] resultContexts = await Task.WhenAll(functionTasks).ConfigureAwait(false);
foreach (FunctionResultContext resultContext in resultContexts)
@@ -487,8 +487,8 @@ public static string ProcessFunctionResult(object functionResult)
return stringResult;
}
- // This is an optimization to use ChatMessageContent content directly
- // without unnecessary serialization of the whole message content class.
+ // This is an optimization to use ChatMessageContent content directly
+ // without unnecessary serialization of the whole message content class.
if (functionResult is ChatMessageContent chatMessageContent)
{
return chatMessageContent.ToString();
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 3f54113ab997..75d052842604 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -257,7 +257,7 @@ public override async Task GetResponseAsync(
// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
// Clear the auto function invocation options.
@@ -337,7 +337,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: true);
// Process all the functions, adding their results into the history.
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, isStreaming: true, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
// Clear the auto function invocation options.
@@ -547,21 +547,23 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
/// The options used for the response being processed.
/// The function call contents representing the functions to be invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
+ /// Whether the function calls are being processed in a streaming context.
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync(
- List messages, ChatOptions options, List functionCallContents, int iteration, CancellationToken cancellationToken)
+ List messages, ChatOptions options, List functionCallContents, int iteration, bool isStreaming, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call.");
+ ContinueMode continueMode = ContinueMode.Continue;
// Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel.
if (functionCallContents.Count == 1)
{
FunctionInvocationResult result = await this.ProcessFunctionCallAsync(
- messages, options, functionCallContents, iteration, 0, cancellationToken).ConfigureAwait(false);
+ messages, options, functionCallContents, iteration, 0, isStreaming, cancellationToken).ConfigureAwait(false);
IList added = this.CreateResponseMessages([result]);
this.ThrowIfNoFunctionResultsAdded(added);
@@ -571,40 +573,54 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
}
else
{
- FunctionInvocationResult[] results;
+ List results = [];
+ var terminationRequested = false;
if (this.AllowConcurrentInvocation)
{
// Schedule the invocation of every function.
- results = await Task.WhenAll(
+ results.AddRange(await Task.WhenAll(
from i in Enumerable.Range(0, functionCallContents.Count)
select Task.Run(() => this.ProcessFunctionCallAsync(
messages, options, functionCallContents,
- iteration, i, cancellationToken))).ConfigureAwait(false);
+ iteration, i, isStreaming, cancellationToken))).ConfigureAwait(false));
+
+ terminationRequested = results.Any(r => r.ContinueMode == ContinueMode.Terminate);
}
else
{
// Invoke each function serially.
- results = new FunctionInvocationResult[functionCallContents.Count];
- for (int i = 0; i < results.Length; i++)
+ for (int i = 0; i < functionCallContents.Count; i++)
{
- results[i] = await this.ProcessFunctionCallAsync(
+ var result = await this.ProcessFunctionCallAsync(
messages, options, functionCallContents,
- iteration, i, cancellationToken).ConfigureAwait(false);
+ iteration, i, isStreaming, cancellationToken).ConfigureAwait(false);
+
+ results.Add(result);
+
+ if (result.ContinueMode == ContinueMode.Terminate)
+ {
+ continueMode = ContinueMode.Terminate;
+ terminationRequested = true;
+ break;
+ }
}
}
- ContinueMode continueMode = ContinueMode.Continue;
-
IList added = this.CreateResponseMessages(results);
this.ThrowIfNoFunctionResultsAdded(added);
-
messages.AddRange(added);
- foreach (FunctionInvocationResult fir in results)
+
+ if (!terminationRequested)
{
- if (fir.ContinueMode > continueMode)
+ // No termination was requested: pick the most severe continue mode returned by any function.
+ continueMode = ContinueMode.Continue;
+ foreach (FunctionInvocationResult fir in results)
{
- continueMode = fir.ContinueMode;
+ if (fir.ContinueMode > continueMode)
+ {
+ continueMode = fir.ContinueMode;
+ }
}
}
@@ -629,11 +645,12 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
/// The function call contents representing all the functions being invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
/// The 0-based index of the function being called out of .
+ /// Whether the function calls are being processed in a streaming context.
/// The to monitor for cancellation requests.
/// A value indicating how the caller should proceed.
private async Task ProcessFunctionCallAsync(
List messages, ChatOptions options, List callContents,
- int iteration, int functionCallIndex, CancellationToken cancellationToken)
+ int iteration, int functionCallIndex, bool isStreaming, CancellationToken cancellationToken)
{
var callContent = callContents[functionCallIndex];
@@ -658,6 +675,7 @@ private async Task ProcessFunctionCallAsync(
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count,
+ IsStreaming = isStreaming
};
object? result;
@@ -700,10 +718,10 @@ internal enum ContinueMode
/// Information about the function call invocations and results.
/// A list of all chat messages created from .
internal IList CreateResponseMessages(
- ReadOnlySpan results)
+ IReadOnlyList results)
{
- var contents = new List(results.Length);
- for (int i = 0; i < results.Length; i++)
+ var contents = new List(results.Count);
+ for (int i = 0; i < results.Count; i++)
{
contents.Add(CreateFunctionResultContent(results[i]));
}
@@ -829,7 +847,6 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false);
context.Result = new FunctionResult(context.Function, result);
}).ConfigureAwait(false);
-
result = autoFunctionInvocationContext.Result.GetValue();
}
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index e642c1a55d8f..842352fe95dd 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -37,10 +37,6 @@ public AutoFunctionInvocationContext(ChatOptions options)
options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings);
this.ExecutionSettings = executionSettings;
- options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.IsStreamingKey, out var isStreaming);
- Verify.NotNull(isStreaming);
- this.IsStreaming = isStreaming.Value;
-
this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture };
}
From df4b108a6f44ee5afecdfba281aa7bde102a119f Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Thu, 10 Apr 2025 19:00:13 +0100
Subject: [PATCH 13/26] Fix warnings
---
.../AI/ChatClient/ChatClientExtensions.cs | 19 +++++++++++--------
.../AI/ChatClient/ChatOptionsExtensions.cs | 11 +++++++----
.../AutoFunctionInvocationContext.cs | 2 +-
3 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
index f3aa9c4ad069..6d6b80fe28bb 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
@@ -29,10 +29,7 @@ internal static Task GetResponseAsync(
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
- var chatOptions = executionSettings?.ToChatOptions(kernel);
-
- // Passing by reference to be used by AutoFunctionInvocationFilters
- chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings;
+ var chatOptions = GetChatOptionsFromSettings(executionSettings, kernel);
// Try to parse the text as a chat history
if (ChatPromptParser.TryParse(prompt, out var chatHistoryFromPrompt))
@@ -58,10 +55,7 @@ internal static IAsyncEnumerable GetStreamingResponseAsync(
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
- var chatOptions = executionSettings?.ToChatOptions(kernel);
-
- // Passing by reference to be used by AutoFunctionInvocationFilters
- chatOptions.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings;
+ var chatOptions = GetChatOptionsFromSettings(executionSettings, kernel);
return chatClient.GetStreamingResponseAsync(prompt, chatOptions, cancellationToken);
}
@@ -109,4 +103,13 @@ public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient cl
? kernelFunctionInvocationClient
: new KernelFunctionInvokingChatClient(client, logger);
}
+
+ private static ChatOptions GetChatOptionsFromSettings(PromptExecutionSettings? executionSettings, Kernel? kernel)
+ {
+ ChatOptions chatOptions = executionSettings?.ToChatOptions(kernel) ?? new ChatOptions().AddKernel(kernel);
+
+ // Passing by reference to be used by AutoFunctionInvocationFilters
+ chatOptions.AdditionalProperties![ChatOptionsExtensions.PromptExecutionSettingsKey] = executionSettings;
+ return chatOptions;
+ }
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
index de2c6831b93d..93d076090dcc 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
@@ -130,13 +130,16 @@ internal static class ChatOptionsExtensions
///
/// Chat options.
/// Kernel to be used for auto function invocation.
- internal static ChatOptions? AddKernel(this ChatOptions options, Kernel kernel)
+ internal static ChatOptions AddKernel(this ChatOptions options, Kernel? kernel)
{
Verify.NotNull(options);
- Verify.NotNull(kernel);
- options.AdditionalProperties ??= [];
- options.AdditionalProperties?.TryAdd(KernelKey, kernel);
+ // Only add the kernel if it is provided
+ if (kernel is not null)
+ {
+ options.AdditionalProperties ??= [];
+ options.AdditionalProperties?.TryAdd(KernelKey, kernel);
+ }
return options;
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index 842352fe95dd..d4fbe564fcd4 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
-using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
@@ -141,6 +140,7 @@ public PromptExecutionSettings? ExecutionSettings
get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings;
init
{
+ this.Options ??= new();
this.Options.AdditionalProperties ??= [];
this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value;
}
From ffc01d925bd1c903afdc0f2fd47ea962ffdfc61d Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 11 Apr 2025 11:19:51 +0100
Subject: [PATCH 14/26] AutoInvocation Skipping and ChatReducer passing
---
...FunctionInvocationFilterChatClientTests.cs | 2 +-
.../Core/AutoFunctionInvocationFilterTests.cs | 2 +-
.../AI/ChatCompletion/ChatHistory.cs | 61 +++++++++++++++--
.../ChatCompletion/ChatHistoryExtensions.cs | 58 +++++++++++++++++
.../AutoFunctionInvocationContext.cs | 65 ++++++++++++++++++-
.../FunctionCallsProcessorTests.cs | 12 ++--
6 files changed, 187 insertions(+), 13 deletions(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index 1abb85a74b52..cb47c48af8d7 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -425,7 +425,7 @@ public async Task FiltersCanSkipFunctionExecutionAsync()
var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
{
// Filter delegate is invoked only for second function, the first one should be skipped.
- if (context.Function.Name == "Function2")
+ if (context.Function.Name == "MyPlugin_Function2")
{
await next(context);
}
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs
index 1f23d84c4dfe..b308206b12d5 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterTests.cs
@@ -596,7 +596,7 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
Assert.Equal([0], requestSequenceNumbers);
Assert.Equal([0], functionSequenceNumbers);
- // Results of function invoked before termination should be returned
+ // Results of function invoked before termination should be returned
Assert.Equal(3, streamingContent.Count);
var lastMessageContent = streamingContent[^1] as StreamingChatMessageContent;
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
index cfdad840349c..0054dc98a400 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
@@ -18,6 +18,14 @@ public class ChatHistory : IList, IReadOnlyListThe messages.
private readonly List _messages;
+ private Action? _overrideAdd;
+ private Func? _overrideRemove;
+ private Action? _overrideClear;
+ private Action? _overrideInsert;
+ private Action? _overrideRemoveAt;
+ private Action? _overrideRemoveRange;
+ private Action>? _overrideAddRange;
+
/// Initializes an empty history.
///
/// Creates a new instance of the class
@@ -27,6 +35,36 @@ public ChatHistory()
this._messages = [];
}
+ // AutoFunctionInvocation is now a dependency of KernelInvocation, so mutations to this history must be mirrored via these overrides.
+ internal void SetOverrides(
+ Action<ChatMessageContent> overrideAdd,
+ Func<ChatMessageContent, bool> overrideRemove,
+ Action onClear,
+ Action<int, ChatMessageContent> overrideInsert,
+ Action<int> overrideRemoveAt,
+ Action<int, int> overrideRemoveRange,
+ Action<IEnumerable<ChatMessageContent>> overrideAddRange)
+ {
+ this._overrideAdd = overrideAdd;
+ this._overrideRemove = overrideRemove;
+ this._overrideClear = onClear;
+ this._overrideInsert = overrideInsert;
+ this._overrideRemoveAt = overrideRemoveAt;
+ this._overrideRemoveRange = overrideRemoveRange;
+ this._overrideAddRange = overrideAddRange;
+ }
+
+ internal void ClearOverrides()
+ {
+ this._overrideAdd = null;
+ this._overrideRemove = null;
+ this._overrideClear = null;
+ this._overrideInsert = null;
+ this._overrideRemoveAt = null;
+ this._overrideRemoveRange = null;
+ this._overrideAddRange = null;
+ }
+
///
/// Creates a new instance of the with a first message in the provided .
/// If not role is provided then the first message will default to role.
@@ -121,6 +159,7 @@ public virtual void Add(ChatMessageContent item)
{
Verify.NotNull(item);
this._messages.Add(item);
+ this._overrideAdd?.Invoke(item);
}
/// Adds the messages to the history.
@@ -130,6 +169,7 @@ public virtual void AddRange(IEnumerable items)
{
Verify.NotNull(items);
this._messages.AddRange(items);
+ this._overrideAddRange?.Invoke(items);
}
/// Inserts a message into the history at the specified index.
@@ -140,6 +180,7 @@ public virtual void Insert(int index, ChatMessageContent item)
{
Verify.NotNull(item);
this._messages.Insert(index, item);
+ this._overrideInsert?.Invoke(index, item);
}
///
@@ -150,10 +191,15 @@ public virtual void Insert(int index, ChatMessageContent item)
/// is null.
/// The number of messages in the history is greater than the available space from to the end of .
/// is less than 0.
- public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex) => this._messages.CopyTo(array, arrayIndex);
+ public virtual void CopyTo(ChatMessageContent[] array, int arrayIndex)
+ => this._messages.CopyTo(array, arrayIndex);
/// Removes all messages from the history.
- public virtual void Clear() => this._messages.Clear();
+ public virtual void Clear()
+ {
+ this._messages.Clear();
+ this._overrideClear?.Invoke();
+ }
/// Gets or sets the message at the specified index in the history.
/// The index of the message to get or set.
@@ -193,7 +239,11 @@ public virtual int IndexOf(ChatMessageContent item)
/// Removes the message at the specified index from the history.
/// The index of the message to remove.
/// The was not valid for this history.
- public virtual void RemoveAt(int index) => this._messages.RemoveAt(index);
+ public virtual void RemoveAt(int index)
+ {
+ this._messages.RemoveAt(index);
+ this._overrideRemoveAt?.Invoke(index);
+ }
/// Removes the first occurrence of the specified message from the history.
/// The message to remove from the history.
@@ -202,7 +252,9 @@ public virtual int IndexOf(ChatMessageContent item)
public virtual bool Remove(ChatMessageContent item)
{
Verify.NotNull(item);
- return this._messages.Remove(item);
+ var result = this._messages.Remove(item);
+ this._overrideRemove?.Invoke(item);
+ return result;
}
///
@@ -216,6 +268,7 @@ public virtual bool Remove(ChatMessageContent item)
public virtual void RemoveRange(int index, int count)
{
this._messages.RemoveRange(index, count);
+ this._overrideRemoveRange?.Invoke(index, count);
}
///
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
index 9702bc5da22e..8a65e502464f 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistoryExtensions.cs
@@ -87,4 +87,62 @@ public static async Task ReduceAsync(this ChatHistory chatHistory,
/// A list of objects.
internal static List<ChatMessage> ToChatMessageList(this ChatHistory chatHistory)
=> chatHistory.Select(m => m.ToChatMessage()).ToList();
+
+ internal static void SetChatMessageHandlers(this ChatHistory chatHistory, IList<ChatMessage> messages)
+ {
+ chatHistory.SetOverrides(Add, Remove, Clear, Insert, RemoveAt, RemoveRange, AddRange);
+
+ void Add(ChatMessageContent item)
+ {
+ messages.Add(item.ToChatMessage());
+ }
+
+ void Clear()
+ {
+ messages.Clear();
+ }
+
+ bool Remove(ChatMessageContent item)
+ {
+ var index = chatHistory.IndexOf(item);
+
+ if (index < 0)
+ {
+ return false;
+ }
+
+ messages.RemoveAt(index);
+
+ return true;
+ }
+
+ void Insert(int index, ChatMessageContent item)
+ {
+ messages.Insert(index, item.ToChatMessage());
+ }
+
+ void RemoveAt(int index)
+ {
+ messages.RemoveAt(index);
+ }
+
+ void RemoveRange(int index, int count)
+ {
+ if (messages is List<ChatMessage> messageList)
+ {
+ messageList.RemoveRange(index, count);
+ return;
+ }
+
+ foreach (var chatMessage in messages.Skip(index).Take(count))
+ {
+ messages.Remove(chatMessage);
+ }
+ }
+
+ void AddRange(IEnumerable<ChatMessageContent> items)
+ {
+ messages.AddRange(items.Select(i => i.ToChatMessage()));
+ }
+ }
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index d4fbe564fcd4..b15ac9b00fe7 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
+using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
+using System.Linq;
using System.Threading;
using Microsoft.Extensions.AI;
using Microsoft.SemanticKernel.ChatCompletion;
@@ -15,6 +17,7 @@ namespace Microsoft.SemanticKernel;
public class AutoFunctionInvocationContext : KernelFunctionInvocationContext
{
private ChatHistory? _chatHistory;
+ private KernelFunction? _kernelFunction;
///
/// Initializes a new instance of the class from an existing .
@@ -69,7 +72,10 @@ public AutoFunctionInvocationContext(
}
};
+ this._kernelFunction = function;
+ this._chatHistory = chatHistory;
this.Messages = chatHistory.ToChatMessageList();
+ chatHistory.SetChatMessageHandlers(this.Messages);
this.AIFunction = function.AsAIFunction();
this.Result = result;
}
@@ -154,7 +160,21 @@ public PromptExecutionSettings? ExecutionSettings
///
/// Gets the with which this filter is associated.
///
- public KernelFunction Function => this.AIFunction.AsKernelFunction();
+ public KernelFunction Function
+ {
+ get
+ {
+ if (this._kernelFunction is null
+ // If the schemas are different,
+ // AIFunction reference potentially was modified and the kernel function should be regenerated.
+ || !IsSameSchema(this._kernelFunction, this.AIFunction))
+ {
+ this._kernelFunction = this.AIFunction.AsKernelFunction();
+ }
+
+ return this._kernelFunction;
+ }
+ }
///
/// Gets the containing services, plugins, and other state for use throughout the operation.
@@ -176,6 +196,18 @@ public Kernel Kernel
///
public FunctionResult Result { get; set; }
+ private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction)
+ {
+ // Compares the schemas, should be similar.
+ return string.Equals(
+ kernelFunction.AsAIFunction().JsonSchema.ToString(),
+ aiFunction.JsonSchema.ToString(),
+ StringComparison.OrdinalIgnoreCase);
+
+ // TODO: Later can be improved by comparing the underlying methods.
+ // return kernelFunction.UnderlyingMethod == aiFunction.UnderlyingMethod;
+ }
+
///
/// Mutable IEnumerable of chat message as chat history.
///
@@ -237,6 +269,30 @@ public override ChatMessageContent this[int index]
}
}
+ public override void RemoveRange(int index, int count)
+ {
+ this._messages.RemoveRange(index, count);
+ base.RemoveRange(index, count);
+ }
+
+ public override void CopyTo(ChatMessageContent[] array, int arrayIndex)
+ {
+ for (int i = 0; i < this._messages.Count; i++)
+ {
+ array[arrayIndex + i] = this._messages[i].ToChatMessageContent();
+ }
+ }
+
+ public override bool Contains(ChatMessageContent item) => base.Contains(item);
+
+ public override int IndexOf(ChatMessageContent item) => base.IndexOf(item);
+
+ public override void AddRange(IEnumerable<ChatMessageContent> items)
+ {
+ base.AddRange(items);
+ this._messages.AddRange(items.Select(i => i.ToChatMessage()));
+ }
+
public override int Count => this._messages.Count;
// Explicit implementation of IEnumerable.GetEnumerator()
@@ -252,4 +308,11 @@ IEnumerator IEnumerable.GetEnumerator()
IEnumerator IEnumerable.GetEnumerator()
=> ((IEnumerable)this).GetEnumerator();
}
+
+ ~AutoFunctionInvocationContext()
+ {
+ // The moment this class is destroyed, we need to clear the overrides and
+ // overrides to update to message
+ this._chatHistory?.ClearOverrides();
+ }
}
diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs
index 9015b2908b2f..ca3f269abc64 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs
@@ -184,7 +184,7 @@ public async Task ItShouldAddFunctionInvocationExceptionToChatHistoryAsync()
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -213,7 +213,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionCallNotAdvertisedAsync(
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -242,7 +242,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionIsNotRegisteredOnKernel
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // The call for function that is not registered on the kernel
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // The call for function that is not registered on the kernel
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -743,7 +743,7 @@ public async Task ItShouldHandleChatMessageContentAsFunctionResultAsync()
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -779,7 +779,7 @@ public async Task ItShouldSerializeFunctionResultOfUnknownTypeAsync()
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -870,7 +870,7 @@ public async Task ItShouldPassPromptExecutionSettingsToAutoFunctionInvocationFil
});
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", arguments: new KernelArguments() { ["parameter"] = "function1-result" }));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test", arguments: new KernelArguments() { ["parameter"] = "function1-result" }));
// Act
await this._sut.ProcessFunctionCallsAsync(
From f7ee7a2aa22922cc25498572a29a45ac4169e11a Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 11 Apr 2025 11:30:06 +0100
Subject: [PATCH 15/26] Fix CallId optionality
---
.../AI/ChatClient/ChatOptionsExtensions.cs | 1 -
.../ChatClient/KernelFunctionInvokingChatClient.cs | 2 +-
.../AutoFunctionInvocationContext.cs | 5 ++---
.../AIConnectors/FunctionCallsProcessorTests.cs | 12 ++++++------
4 files changed, 9 insertions(+), 11 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
index 93d076090dcc..5db8240b1707 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
@@ -2,7 +2,6 @@
using System;
using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text.Json;
using Microsoft.Extensions.AI;
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 75d052842604..1d48111b47af 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -846,7 +846,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
// as the called function could in turn telling the model about itself as a possible candidate for invocation.
result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false);
context.Result = new FunctionResult(context.Function, result);
- }).ConfigureAwait(false);
+ }).ConfigureAwait(false);
result = autoFunctionInvocationContext.Result.GetValue();
}
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index b15ac9b00fe7..e3318555b129 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -126,9 +126,7 @@ public string? ToolCallId
get => this.CallContent.CallId;
init
{
- Verify.NotNull(value);
- // ToolCallId
- this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value, this.CallContent.Name, this.CallContent.Arguments);
+ this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value ?? string.Empty, this.CallContent.Name, this.CallContent.Arguments);
}
}
@@ -309,6 +307,7 @@ IEnumerator IEnumerable.GetEnumerator()
=> ((IEnumerable)this).GetEnumerator();
}
+ /// <summary>Finalizer that clears the chat history overrides so the history stops mirroring into this context's message list.</summary>
~AutoFunctionInvocationContext()
{
// The moment this class is destroyed, we need to clear the overrides and
diff --git a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs
index ca3f269abc64..9015b2908b2f 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Utilities/AIConnectors/FunctionCallsProcessorTests.cs
@@ -184,7 +184,7 @@ public async Task ItShouldAddFunctionInvocationExceptionToChatHistoryAsync()
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -213,7 +213,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionCallNotAdvertisedAsync(
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -242,7 +242,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionIsNotRegisteredOnKernel
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test")); // The call for function that is not registered on the kernel
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin")); // The call for function that is not registered on the kernel
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -743,7 +743,7 @@ public async Task ItShouldHandleChatMessageContentAsFunctionResultAsync()
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -779,7 +779,7 @@ public async Task ItShouldSerializeFunctionResultOfUnknownTypeAsync()
var chatHistory = new ChatHistory();
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test"));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin"));
// Act
await this._sut.ProcessFunctionCallsAsync(
@@ -870,7 +870,7 @@ public async Task ItShouldPassPromptExecutionSettingsToAutoFunctionInvocationFil
});
var chatMessageContent = new ChatMessageContent();
- chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", id: "callid_test", arguments: new KernelArguments() { ["parameter"] = "function1-result" }));
+ chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", arguments: new KernelArguments() { ["parameter"] = "function1-result" }));
// Act
await this._sut.ProcessFunctionCallsAsync(
From 7344afd3489b9673826fdb6ea481a16c674733dd Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 11 Apr 2025 12:26:11 +0100
Subject: [PATCH 16/26] Fix warnings
---
.../Core/AutoFunctionInvocationFilterChatClientTests.cs | 1 -
1 file changed, 1 deletion(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
index cb47c48af8d7..dd8d94c99824 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Core/AutoFunctionInvocationFilterChatClientTests.cs
@@ -11,7 +11,6 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
-using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Xunit;
From 196fa98c426120e77593e3be38f077db316e28e8 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 11 Apr 2025 13:01:50 +0100
Subject: [PATCH 17/26] Fix 9.4.0 conflicts and errors
---
.../Agents/ChatCompletion_ServiceSelection.cs | 10 ++--
.../Kernel/CustomAIServiceSelector.cs | 8 +--
.../Step06_DependencyInjection.cs | 22 ++++----
...IServiceCollectionExtensions.ChatClient.cs | 24 ++++-----
.../OpenAI/OpenAIAudioToTextTests.cs | 2 +-
.../OpenAI/OpenAIChatCompletionTests.cs | 53 +++++++++----------
.../OpenAI/OpenAITextToAudioTests.cs | 2 +-
.../samples/InternalUtilities/BaseTest.cs | 22 ++++----
.../AI/ChatClient/AIFunctionFactory.cs | 12 ++---
.../AI/ChatClient/ChatClientAIService.cs | 2 +-
.../AI/ChatClient/ChatClientExtensions.cs | 2 +-
.../KernelFunctionInvokingChatClient.cs | 2 +-
.../Functions/KernelFunctionFromPrompt.cs | 4 +-
.../AIFunctionKernelFunctionTests.cs | 4 +-
.../CustomAIChatClientSelectorTests.cs | 2 +-
.../OrderedAIServiceSelectorTests.cs | 2 +-
16 files changed, 87 insertions(+), 86 deletions(-)
diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs
index a0aa892a7802..be69fe412d5e 100644
--- a/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs
+++ b/dotnet/samples/Concepts/Agents/ChatCompletion_ServiceSelection.cs
@@ -109,11 +109,11 @@ private Kernel CreateKernelWithTwoServices(bool useChatClient)
{
builder.Services.AddKeyedChatClient(
ServiceKeyBad,
- new OpenAI.OpenAIClient("bad-key").AsChatClient(TestConfiguration.OpenAI.ChatModelId));
+ new OpenAI.OpenAIClient("bad-key").GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient());
builder.Services.AddKeyedChatClient(
ServiceKeyGood,
- new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).AsChatClient(TestConfiguration.OpenAI.ChatModelId));
+ new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey).GetChatClient(TestConfiguration.OpenAI.ChatModelId).AsIChatClient());
}
else
{
@@ -122,14 +122,16 @@ private Kernel CreateKernelWithTwoServices(bool useChatClient)
new Azure.AI.OpenAI.AzureOpenAIClient(
new Uri(TestConfiguration.AzureOpenAI.Endpoint),
new Azure.AzureKeyCredential("bad-key"))
- .AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName));
+ .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)
+ .AsIChatClient());
builder.Services.AddKeyedChatClient(
ServiceKeyGood,
new Azure.AI.OpenAI.AzureOpenAIClient(
new Uri(TestConfiguration.AzureOpenAI.Endpoint),
new Azure.AzureKeyCredential(TestConfiguration.AzureOpenAI.ApiKey))
- .AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName));
+ .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)
+ .AsIChatClient());
}
}
else
diff --git a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs
index d4631323c24d..02ddbdb3ec35 100644
--- a/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs
+++ b/dotnet/samples/Concepts/Kernel/CustomAIServiceSelector.cs
@@ -10,7 +10,7 @@
namespace KernelExamples;
///
-/// This sample shows how to use a custom AI service selector to select a specific model by matching it's id.
+/// This sample shows how to use a custom AI service selector to select a specific model by matching the model id.
///
public class CustomAIServiceSelector(ITestOutputHelper output) : BaseTest(output)
{
@@ -39,7 +39,8 @@ public async Task UsingCustomSelectToSelectServiceByMatchingModelId()
builder.Services
.AddSingleton(customSelector)
.AddKeyedChatClient("OpenAIChatClient", new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey)
- .AsChatClient("gpt-4o")); // Add a IChatClient to the kernel
+ .GetChatClient("gpt-4o")
+ .AsIChatClient()); // Add a IChatClient to the kernel
Kernel kernel = builder.Build();
@@ -60,7 +61,6 @@ private sealed class GptAIServiceSelector(string modelNameStartsWith, ITestOutpu
private readonly ITestOutputHelper _output = output;
private readonly string _modelNameStartsWith = modelNameStartsWith;
- ///
private bool TrySelect(
Kernel kernel, KernelFunction function, KernelArguments arguments,
[NotNullWhen(true)] out T? service, out PromptExecutionSettings? serviceSettings) where T : class
@@ -78,7 +78,7 @@ private bool TrySelect(
else if (serviceToCheck is IChatClient chatClient)
{
var metadata = chatClient.GetService();
- serviceModelId = metadata?.ModelId;
+ serviceModelId = metadata?.DefaultModelId;
endpoint = metadata?.ProviderUri?.ToString();
}
diff --git a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs
index 8935e4d66d48..39106b957841 100644
--- a/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs
+++ b/dotnet/samples/GettingStartedWithAgents/Step06_DependencyInjection.cs
@@ -43,25 +43,25 @@ public async Task UseDependencyInjectionToCreateAgentAsync(bool useChatClient)
IChatClient chatClient;
if (this.UseOpenAIConfig)
{
- chatClient = new Microsoft.Extensions.AI.OpenAIChatClient(
- new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey),
- TestConfiguration.OpenAI.ChatModelId);
+ chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey)
+ .GetChatClient(TestConfiguration.OpenAI.ChatModelId)
+ .AsIChatClient();
}
else if (!string.IsNullOrEmpty(this.ApiKey))
{
- chatClient = new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient: new AzureOpenAIClient(
+ chatClient = new AzureOpenAIClient(
endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint),
- credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)),
- modelId: TestConfiguration.AzureOpenAI.ChatModelId);
+ credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey))
+ .GetChatClient(TestConfiguration.OpenAI.ChatModelId)
+ .AsIChatClient();
}
else
{
- chatClient = new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient: new AzureOpenAIClient(
+ chatClient = new AzureOpenAIClient(
endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint),
- credential: new AzureCliCredential()),
- modelId: TestConfiguration.AzureOpenAI.ChatModelId);
+ credential: new AzureCliCredential())
+ .GetChatClient(TestConfiguration.OpenAI.ChatModelId)
+ .AsIChatClient();
}
var functionCallingChatClient = chatClient!.AsKernelFunctionInvokingChatClient();
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
index 2ae915869698..7fa002d806ea 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
@@ -47,11 +47,11 @@ public static IServiceCollection AddOpenAIChatClient(
IChatClient Factory(IServiceProvider serviceProvider, object? _)
{
- ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+ ILogger? logger = serviceProvider.GetService()?.CreateLogger();
- return new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider))),
- modelId: modelId)
+ return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider)))
+ .GetChatClient(modelId)
+ .AsIChatClient()
.AsKernelFunctionInvokingChatClient(logger);
}
@@ -77,11 +77,11 @@ public static IServiceCollection AddOpenAIChatClient(this IServiceCollection ser
IChatClient Factory(IServiceProvider serviceProvider, object? _)
{
- ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+ ILogger? logger = serviceProvider.GetService()?.CreateLogger();
- return new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient ?? serviceProvider.GetRequiredService(),
- modelId)
+ return (openAIClient ?? serviceProvider.GetRequiredService())
+ .GetChatClient(modelId)
+ .AsIChatClient()
.AsKernelFunctionInvokingChatClient(logger);
}
@@ -114,11 +114,11 @@ public static IServiceCollection AddOpenAIChatClient(
IChatClient Factory(IServiceProvider serviceProvider, object? _)
{
- ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+ ILogger? logger = serviceProvider.GetService()?.CreateLogger();
- return new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient: new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider))),
- modelId: modelId)
+ return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)))
+ .GetChatClient(modelId)
+ .AsIChatClient()
.AsKernelFunctionInvokingChatClient(logger);
}
diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs
index 90375307c533..9e1127fa8b55 100644
--- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIAudioToTextTests.cs
@@ -22,7 +22,7 @@ public sealed class OpenAIAudioToTextTests()
.AddUserSecrets()
.Build();
- [RetryFact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
+ [RetryFact] //(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
public async Task OpenAIAudioToTextTestAsync()
{
// Arrange
diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs
index fe8ff155d9c5..ddfe6b997a25 100644
--- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletionTests.cs
@@ -12,7 +12,6 @@
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http.Resilience;
-using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using OpenAI;
@@ -48,16 +47,16 @@ public async Task ItCanUseOpenAiChatForTextGenerationAsync()
[Fact]
public async Task ItCanUseOpenAiChatClientAndContentsAsync()
{
- var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get();
- Assert.NotNull(OpenAIConfiguration);
- Assert.NotNull(OpenAIConfiguration.ChatModelId);
- Assert.NotNull(OpenAIConfiguration.ApiKey);
- Assert.NotNull(OpenAIConfiguration.ServiceId);
+ var openAIConfiguration = this._configuration.GetSection("OpenAI").Get();
+ Assert.NotNull(openAIConfiguration);
+ Assert.NotNull(openAIConfiguration.ChatModelId);
+ Assert.NotNull(openAIConfiguration.ApiKey);
+ Assert.NotNull(openAIConfiguration.ServiceId);
// Arrange
- var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey);
+ var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey);
var builder = Kernel.CreateBuilder();
- builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId));
+ builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient());
var kernel = builder.Build();
var func = kernel.CreateFunctionFromPrompt(
@@ -104,16 +103,16 @@ public async Task OpenAIStreamingTestAsync()
[Fact]
public async Task ItCanUseOpenAiStreamingChatClientAndContentsAsync()
{
- var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get();
- Assert.NotNull(OpenAIConfiguration);
- Assert.NotNull(OpenAIConfiguration.ChatModelId);
- Assert.NotNull(OpenAIConfiguration.ApiKey);
- Assert.NotNull(OpenAIConfiguration.ServiceId);
+ var openAIConfiguration = this._configuration.GetSection("OpenAI").Get();
+ Assert.NotNull(openAIConfiguration);
+ Assert.NotNull(openAIConfiguration.ChatModelId);
+ Assert.NotNull(openAIConfiguration.ApiKey);
+ Assert.NotNull(openAIConfiguration.ServiceId);
// Arrange
- var openAIClient = new OpenAIClient(OpenAIConfiguration.ApiKey);
+ var openAIClient = new OpenAIClient(openAIConfiguration.ApiKey);
var builder = Kernel.CreateBuilder();
- builder.Services.AddChatClient(openAIClient.AsChatClient(OpenAIConfiguration.ChatModelId));
+ builder.Services.AddChatClient(openAIClient.GetChatClient(openAIConfiguration.ChatModelId).AsIChatClient());
var kernel = builder.Build();
var plugins = TestHelpers.ImportSamplePlugins(kernel, "ChatPlugin");
@@ -179,7 +178,7 @@ public async Task OpenAIHttpRetryPolicyTestAsync()
// Assert
Assert.All(statusCodes, s => Assert.Equal(HttpStatusCode.Unauthorized, s));
- Assert.Equal(HttpStatusCode.Unauthorized, ((HttpOperationException)exception).StatusCode);
+ Assert.Equal(HttpStatusCode.Unauthorized, exception.StatusCode);
}
[Fact]
@@ -258,11 +257,11 @@ public async Task SemanticKernelVersionHeaderIsSentAsync()
var kernel = this.CreateAndInitializeKernel(httpClient);
// Act
- var result = await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?");
+ await kernel.InvokePromptAsync("Where is the most famous fish market in Seattle, Washington, USA?");
// Assert
Assert.NotNull(httpHeaderHandler.RequestHeaders);
- Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values));
+ Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var _));
}
//[Theory(Skip = "This test is for manual verification.")]
@@ -301,18 +300,18 @@ public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int?
private Kernel CreateAndInitializeKernel(HttpClient? httpClient = null)
{
- var OpenAIConfiguration = this._configuration.GetSection("OpenAI").Get();
- Assert.NotNull(OpenAIConfiguration);
- Assert.NotNull(OpenAIConfiguration.ChatModelId);
- Assert.NotNull(OpenAIConfiguration.ApiKey);
- Assert.NotNull(OpenAIConfiguration.ServiceId);
+ var openAIConfiguration = this._configuration.GetSection("OpenAI").Get();
+ Assert.NotNull(openAIConfiguration);
+ Assert.NotNull(openAIConfiguration.ChatModelId);
+ Assert.NotNull(openAIConfiguration.ApiKey);
+ Assert.NotNull(openAIConfiguration.ServiceId);
- var kernelBuilder = base.CreateKernelBuilder();
+ var kernelBuilder = this.CreateKernelBuilder();
kernelBuilder.AddOpenAIChatCompletion(
- modelId: OpenAIConfiguration.ChatModelId,
- apiKey: OpenAIConfiguration.ApiKey,
- serviceId: OpenAIConfiguration.ServiceId,
+ modelId: openAIConfiguration.ChatModelId,
+ apiKey: openAIConfiguration.ApiKey,
+ serviceId: openAIConfiguration.ServiceId,
httpClient: httpClient);
return kernelBuilder.Build();
diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs
index c2818abe2502..420295fe4349 100644
--- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs
+++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAITextToAudioTests.cs
@@ -18,7 +18,7 @@ public sealed class OpenAITextToAudioTests
.AddUserSecrets()
.Build();
- [Fact]//(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
+ [Fact] //(Skip = "OpenAI will often throttle requests. This test is for manual verification.")]
public async Task OpenAITextToAudioTestAsync()
{
// Arrange
diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs
index 2fefb6ee9d16..3a8d561f4eaf 100644
--- a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs
+++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs
@@ -99,25 +99,25 @@ protected IChatClient AddChatClientToKernel(IKernelBuilder builder)
IChatClient chatClient;
if (this.UseOpenAIConfig)
{
- chatClient = new Microsoft.Extensions.AI.OpenAIChatClient(
- new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey),
- TestConfiguration.OpenAI.ChatModelId);
+ chatClient = new OpenAI.OpenAIClient(TestConfiguration.OpenAI.ApiKey)
+ .GetChatClient(TestConfiguration.OpenAI.ChatModelId)
+ .AsIChatClient();
}
else if (!string.IsNullOrEmpty(this.ApiKey))
{
- chatClient = new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient: new AzureOpenAIClient(
+ chatClient = new AzureOpenAIClient(
endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint),
- credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey)),
- modelId: TestConfiguration.AzureOpenAI.ChatDeploymentName);
+ credential: new ApiKeyCredential(TestConfiguration.AzureOpenAI.ApiKey))
+ .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)
+ .AsIChatClient();
}
else
{
- chatClient = new Microsoft.Extensions.AI.OpenAIChatClient(
- openAIClient: new AzureOpenAIClient(
+ chatClient = new AzureOpenAIClient(
endpoint: new Uri(TestConfiguration.AzureOpenAI.Endpoint),
- credential: new AzureCliCredential()),
- modelId: TestConfiguration.AzureOpenAI.ChatDeploymentName);
+ credential: new AzureCliCredential())
+ .GetChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName)
+ .AsIChatClient();
}
var functionCallingChatClient = chatClient!.AsKernelFunctionInvokingChatClient();
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs
index a0d6b1865a8f..e7afdd77e153 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs
@@ -194,8 +194,8 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method;
public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema;
public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
- protected override Task InvokeCoreAsync(
- IEnumerable>? arguments,
+ protected override ValueTask InvokeCoreAsync(
+ AIFunctionArguments? arguments,
CancellationToken cancellationToken)
{
var paramMarshallers = FunctionDescriptor.ParameterMarshallers;
@@ -283,7 +283,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
public JsonSerializerOptions JsonSerializerOptions { get; }
public JsonElement JsonSchema { get; }
public Func, CancellationToken, object?>[] ParameterMarshallers { get; }
- public Func> ReturnParameterMarshaller { get; }
+ public Func> ReturnParameterMarshaller { get; }
public ReflectionAIFunction? CachedDefaultInstance { get; set; }
private static string GetFunctionName(MethodInfo method)
@@ -395,7 +395,7 @@ static bool IsAsyncMethod(MethodInfo method)
///
/// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation.
///
- private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions)
+ private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions)
{
Type returnType = method.ReturnType;
JsonTypeInfo returnTypeInfo;
@@ -403,7 +403,7 @@ static bool IsAsyncMethod(MethodInfo method)
// Void
if (returnType == typeof(void))
{
- return static (_, _) => Task.FromResult(null);
+ return static (_, _) => new ValueTask((object?)null);
}
// Task
@@ -461,7 +461,7 @@ static bool IsAsyncMethod(MethodInfo method)
returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);
- static async Task SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken)
+ static async ValueTask SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken)
{
if (returnTypeInfo.Kind is JsonTypeInfoKind.None)
{
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs
index 8a5abc42a6e0..b840b33e690b 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs
@@ -33,7 +33,7 @@ internal ChatClientAIService(IChatClient chatClient)
var metadata = this._chatClient.GetService();
Verify.NotNull(metadata);
- this._internalAttributes[nameof(metadata.ModelId)] = metadata.ModelId;
+ this._internalAttributes["ModelId"] = metadata.DefaultModelId;
this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName;
this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri;
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
index 6d6b80fe28bb..0ffdd5fec99d 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
@@ -85,7 +85,7 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl
{
Verify.NotNull(client);
- return client.GetService()?.ModelId;
+ return client.GetService()?.DefaultModelId;
}
///
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 1d48111b47af..10552971aba8 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -844,7 +844,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
// Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
// further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
// as the called function could in turn telling the model about itself as a possible candidate for invocation.
- result = await invocationContext.AIFunction.InvokeAsync(autoFunctionInvocationContext.Arguments, cancellationToken).ConfigureAwait(false);
+ result = await invocationContext.AIFunction.InvokeAsync(new(autoFunctionInvocationContext.Arguments), cancellationToken).ConfigureAwait(false);
context.Result = new FunctionResult(context.Function, result);
}).ConfigureAwait(false);
result = autoFunctionInvocationContext.Result.GetValue();
diff --git a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
index 6446dd1b97d5..8037f513ee65 100644
--- a/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
+++ b/dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs
@@ -818,10 +818,10 @@ private async Task GetChatClientResultAsync(
};
}
- var modelId = chatClient.GetService()?.ModelId;
+ var modelId = chatClient.GetService()?.DefaultModelId;
// Usage details are global and duplicated for each chat message content, use first one to get usage information
- this.CaptureUsageDetails(chatClient.GetService()?.ModelId, chatResponse.Usage, this._logger);
+ this.CaptureUsageDetails(chatClient.GetService()?.DefaultModelId, chatResponse.Usage, this._logger);
return new FunctionResult(this, chatResponse)
{
diff --git a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs
index 065784118c3d..20a2460e1751 100644
--- a/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/AI/ChatCompletion/AIFunctionKernelFunctionTests.cs
@@ -14,8 +14,8 @@ public class AIFunctionKernelFunctionTests
public void ShouldAssignIsRequiredParameterMetadataPropertyCorrectly()
{
// Arrange and Act
- AIFunction aiFunction = AIFunctionFactory.Create((string p1, int? p2 = null) => p1,
- new AIFunctionFactoryOptions { JsonSchemaCreateOptions = new AIJsonSchemaCreateOptions { RequireAllProperties = false } });
+ AIFunction aiFunction = Microsoft.Extensions.AI.AIFunctionFactory.Create((string p1, int? p2 = null) => p1,
+ new Microsoft.Extensions.AI.AIFunctionFactoryOptions { JsonSchemaCreateOptions = new AIJsonSchemaCreateOptions { RequireAllProperties = false } });
AIFunctionKernelFunction sut = new(aiFunction);
diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs
index 322c5f3b935f..7400875cdc9b 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Functions/CustomAIChatClientSelectorTests.cs
@@ -70,7 +70,7 @@ private sealed class ChatClientTest : IChatClient
public ChatClientTest()
{
- this._metadata = new ChatClientMetadata(modelId: "Value1");
+ this._metadata = new ChatClientMetadata(defaultModelId: "Value1");
}
public void Dispose()
diff --git a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs
index b0a57fa8cdf6..52619c22af6e 100644
--- a/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs
+++ b/dotnet/src/SemanticKernel.UnitTests/Functions/OrderedAIServiceSelectorTests.cs
@@ -527,7 +527,7 @@ private sealed class ChatClient : IChatClient
public ChatClient(string modelId)
{
- this.Metadata = new ChatClientMetadata(modelId: modelId);
+ this.Metadata = new ChatClientMetadata(defaultModelId: modelId);
}
public Task> GetTextContentsAsync(string prompt, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
From 83c2e18389303c1715db6dbaba49ac409ecdcf47 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Fri, 11 Apr 2025 17:21:47 +0100
Subject: [PATCH 18/26] Internalizing components
---
.../KernelFunctionInvocationContext.cs | 6 +-
.../KernelFunctionInvokingChatClient.cs | 95 ++++++++--------
.../AutoFunctionInvocationContext.cs | 106 ++++++++++++------
3 files changed, 123 insertions(+), 84 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
index a36f03630fd5..c8f293b2bb0c 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
@@ -16,10 +16,10 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
/// Provides context for an in-flight function invocation.
[ExcludeFromCodeCoverage]
-public class KernelFunctionInvocationContext
+internal class KernelFunctionInvocationContext
{
///
- /// A nop function used to allow to be non-nullable. Default instances of
+ /// A nop function used to allow to be non-nullable. Default instances of
/// start with this as the target function.
///
private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext));
@@ -64,7 +64,7 @@ public IList Messages
public ChatOptions? Options { get; set; }
/// Gets or sets the AI function to be invoked.
- public AIFunction AIFunction
+ public AIFunction Function
{
get => _function;
set
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 10552971aba8..3a5eb88b5cf3 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -47,8 +47,8 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
[ExcludeFromCodeCoverage]
internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient
{
- /// The for the current function invocation.
- private static readonly AsyncLocal s_currentContext = new();
+ /// The for the current function invocation.
+ private static readonly AsyncLocal s_currentContext = new();
/// The logger to use for logging information about function invocation.
private readonly ILogger _logger;
@@ -78,7 +78,7 @@ public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger
///
/// This value flows across async calls.
///
- internal static KernelFunctionInvocationContext? CurrentContext
+ internal static AutoFunctionInvocationContext? CurrentContext
{
get => s_currentContext.Value;
set => s_currentContext.Value = value;
@@ -120,13 +120,13 @@ internal static KernelFunctionInvocationContext? CurrentContext
///
///
/// Setting the value to can help the underlying bypass problems on
- /// its own, for example by retrying the function call with different arguments. However it might
+ /// its own, for example by retrying the function call with different arguments. However, it might
/// result in disclosing the raw exception information to external users, which can be a security
/// concern depending on the application scenario.
///
///
/// Changing the value of this property while the client is in use might result in inconsistencies
- /// as to whether detailed errors are provided during an in-flight request.
+ /// in whether detailed errors are provided during an in-flight request.
///
///
public bool IncludeDetailedErrors { get; set; }
@@ -227,7 +227,7 @@ public override async Task GetResponseAsync(
return response;
}
- // Track aggregatable details from the response, including all of the response messages and usage details.
+ // Track aggregate details from the response, including all the response messages and usage details.
(responseMessages ??= []).AddRange(response.Messages);
if (response.Usage is not null)
{
@@ -257,13 +257,13 @@ public override async Task GetResponseAsync(
// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
// Clear the auto function invocation options.
- ClearOptionsForAutoFunctionInvocation(ref options!);
+ ClearOptionsForAutoFunctionInvocation(ref options);
- if (UpdateOptionsForMode(modeAndMessages.Mode, ref options!, response.ChatThreadId))
+ if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId))
{
// Terminate
break;
@@ -271,7 +271,7 @@ public override async Task GetResponseAsync(
}
Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
- response.Messages = responseMessages!;
+ response.Messages = responseMessages;
response.Usage = totalUsage;
return response;
@@ -334,14 +334,14 @@ public override async IAsyncEnumerable GetStreamingResponseA
FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
// Prepare the options for the next auto function invocation iteration.
- UpdateOptionsForAutoFunctionInvocation(ref options!, response.Messages.Last().ToChatMessageContent(), isStreaming: true);
+ UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true);
// Process all the functions, adding their results into the history.
var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, isStreaming: true, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
// Clear the auto function invocation options.
- ClearOptionsForAutoFunctionInvocation(ref options!);
+ ClearOptionsForAutoFunctionInvocation(ref options);
// Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
// include all activities, including generated function results.
@@ -408,7 +408,7 @@ private static void FixupHistories(
{
// In the very rare case where the inner client returned a response with a thread ID but then
// returned a subsequent response without one, we want to reconstitute the full history. To do that,
- // we can populate the history with the original chat messages and then all of the response
+ // we can populate the history with the original chat messages and then all the response
// messages up until this point, which includes the most recent ones.
augmentedHistory ??= [];
augmentedHistory.Clear();
@@ -667,16 +667,16 @@ private async Task ProcessFunctionCallAsync(
callContent.Arguments = new KernelArguments(callContent.Arguments);
}
- var context = new AutoFunctionInvocationContext(options)
+ var context = new AutoFunctionInvocationContext(new KernelFunctionInvocationContext
{
+ Options = options,
Messages = messages,
CallContent = callContent,
- AIFunction = function,
+ Function = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count,
- IsStreaming = isStreaming
- };
+ }) { IsStreaming = isStreaming };
object? result;
try
@@ -721,9 +721,9 @@ internal IList CreateResponseMessages(
IReadOnlyList results)
{
var contents = new List(results.Count);
- for (int i = 0; i < results.Count; i++)
+ foreach (var t in results)
{
- contents.Add(CreateFunctionResultContent(results[i]));
+ contents.Add(CreateFunctionResultContent(t));
}
return [new(ChatRole.Tool, contents)];
@@ -778,20 +778,20 @@ private async Task OnAutoFunctionInvocationAsync(
/// If there are no registered filters, just function will be executed.
/// If there are registered filters, filter on position will be executed.
/// Second parameter of filter is callback. It can be either filter on + 1 position or function if there are no remaining filters to execute.
- /// Function will be always executed as last step after all filters.
+ /// Function will always be executed as last step after all filters.
///
private async Task InvokeFilterOrFunctionAsync(
Func functionCallCallback,
AutoFunctionInvocationContext context,
int index = 0)
{
- IList? autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters;
+ IList autoFunctionInvocationFilters = context.Kernel.AutoFunctionInvocationFilters;
if (autoFunctionInvocationFilters is { Count: > 0 } && index < autoFunctionInvocationFilters.Count)
{
await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
context,
- (context) => this.InvokeFilterOrFunctionAsync(functionCallCallback, context, index + 1)
+ (ctx) => this.InvokeFilterOrFunctionAsync(functionCallCallback, ctx, index + 1)
).ConfigureAwait(false);
}
else
@@ -805,11 +805,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
/// is .
- internal async Task InvokeFunctionAsync(KernelFunctionInvocationContext invocationContext, CancellationToken cancellationToken)
+ private async Task InvokeFunctionAsync(AutoFunctionInvocationContext invocationContext, CancellationToken cancellationToken)
{
Verify.NotNull(invocationContext);
- using Activity? activity = this._activitySource?.StartActivity(invocationContext.AIFunction.Name);
+ using Activity? activity = this._activitySource?.StartActivity(invocationContext.Function.Name);
long startingTimestamp = 0;
if (this._logger.IsEnabled(LogLevel.Debug))
@@ -817,11 +817,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
startingTimestamp = Stopwatch.GetTimestamp();
if (this._logger.IsEnabled(LogLevel.Trace))
{
- this.LogInvokingSensitive(invocationContext.AIFunction.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions));
+ this.LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions));
}
else
{
- this.LogInvoking(invocationContext.AIFunction.Name);
+ this.LogInvoking(invocationContext.Function.Name);
}
}
@@ -829,26 +829,23 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
try
{
CurrentContext = invocationContext;
- if (invocationContext is AutoFunctionInvocationContext autoFunctionInvocationContext)
- {
- autoFunctionInvocationContext = await this.OnAutoFunctionInvocationAsync(
- autoFunctionInvocationContext,
- async (context) =>
+ invocationContext = await this.OnAutoFunctionInvocationAsync(
+ invocationContext,
+ async (context) =>
+ {
+ // Check if filter requested termination
+ if (context.Terminate)
{
- // Check if filter requested termination
- if (context.Terminate)
- {
- return;
- }
-
- // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
- // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
- // as the called function could in turn telling the model about itself as a possible candidate for invocation.
- result = await invocationContext.AIFunction.InvokeAsync(new(autoFunctionInvocationContext.Arguments), cancellationToken).ConfigureAwait(false);
- context.Result = new FunctionResult(context.Function, result);
- }).ConfigureAwait(false);
- result = autoFunctionInvocationContext.Result.GetValue();
- }
+ return;
+ }
+
+ // Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
+ // further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
+ // as the called function could in turn tell the model about itself as a possible candidate for invocation.
+ result = await invocationContext.AIFunction.InvokeAsync(new(invocationContext.Arguments), cancellationToken).ConfigureAwait(false);
+ context.Result = new FunctionResult(context.Function, result);
+ }).ConfigureAwait(false);
+ result = invocationContext.Result.GetValue();
}
catch (Exception e)
{
@@ -860,11 +857,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (e is OperationCanceledException)
{
- this.LogInvocationCanceled(invocationContext.AIFunction.Name);
+ this.LogInvocationCanceled(invocationContext.Function.Name);
}
else
{
- this.LogInvocationFailed(invocationContext.AIFunction.Name, e);
+ this.LogInvocationFailed(invocationContext.Function.Name, e);
}
throw;
@@ -877,11 +874,11 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (result is not null && this._logger.IsEnabled(LogLevel.Trace))
{
- this.LogInvocationCompletedSensitive(invocationContext.AIFunction.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions));
+ this.LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions));
}
else
{
- this.LogInvocationCompleted(invocationContext.AIFunction.Name, elapsed);
+ this.LogInvocationCompleted(invocationContext.Function.Name, elapsed);
}
}
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index d46e006e7ebb..f70205365330 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -13,30 +13,32 @@ namespace Microsoft.SemanticKernel;
///
/// Class with data related to automatic function invocation.
///
-public class AutoFunctionInvocationContext : KernelFunctionInvocationContext
+public class AutoFunctionInvocationContext
{
private ChatHistory? _chatHistory;
private KernelFunction? _kernelFunction;
+ private readonly KernelFunctionInvocationContext _invocationContext = new();
///
/// Initializes a new instance of the class from an existing .
///
- public AutoFunctionInvocationContext(ChatOptions options)
+ internal AutoFunctionInvocationContext(KernelFunctionInvocationContext invocationContext)
{
- this.Options = options;
+ Verify.NotNull(invocationContext);
+ Verify.NotNull(invocationContext.Options);
- // To create a AutoFunctionInvocationContext from a KernelFunctionInvocationContext,
// the ChatOptions must be provided with AdditionalProperties.
- Verify.NotNull(options.AdditionalProperties);
+ Verify.NotNull(invocationContext.Options.AdditionalProperties);
- options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel);
+ invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.KernelKey, out var kernel);
Verify.NotNull(kernel);
- options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent);
+ invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.ChatMessageContentKey, out var chatMessageContent);
Verify.NotNull(chatMessageContent);
- options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings);
+ invocationContext.Options.AdditionalProperties.TryGetValue(ChatOptionsExtensions.PromptExecutionSettingsKey, out var executionSettings);
this.ExecutionSettings = executionSettings;
+ this._invocationContext = invocationContext;
this.Result = new FunctionResult(this.Function) { Culture = kernel.Culture };
}
@@ -62,7 +64,7 @@ public AutoFunctionInvocationContext(
Verify.NotNull(chatHistory);
Verify.NotNull(chatMessageContent);
- this.Options = new()
+ this._invocationContext.Options = new()
{
AdditionalProperties = new()
{
@@ -73,9 +75,9 @@ public AutoFunctionInvocationContext(
this._kernelFunction = function;
this._chatHistory = chatHistory;
- this.Messages = chatHistory.ToChatMessageList();
- chatHistory.SetChatMessageHandlers(this.Messages);
- this.AIFunction = function.AsAIFunction();
+ this._invocationContext.Messages = chatHistory.ToChatMessageList();
+ chatHistory.SetChatMessageHandlers(this._invocationContext.Messages);
+ this._invocationContext.Function = function.AsAIFunction();
this.Result = result;
}
@@ -95,8 +97,8 @@ public AutoFunctionInvocationContext(
///
public KernelArguments? Arguments
{
- get => this.CallContent.Arguments is KernelArguments kernelArguments ? kernelArguments : null;
- init => this.CallContent.Arguments = value;
+ get => this._invocationContext.CallContent.Arguments is KernelArguments kernelArguments ? kernelArguments : null;
+ init => this._invocationContext.CallContent.Arguments = value;
}
///
@@ -104,8 +106,8 @@ public KernelArguments? Arguments
///
public int RequestSequenceIndex
{
- get => this.Iteration;
- init => this.Iteration = value;
+ get => this._invocationContext.Iteration;
+ init => this._invocationContext.Iteration = value;
}
///
@@ -113,8 +115,19 @@ public int RequestSequenceIndex
///
public int FunctionSequenceIndex
{
- get => this.FunctionCallIndex;
- init => this.FunctionCallIndex = value;
+ get => this._invocationContext.FunctionCallIndex;
+ init => this._invocationContext.FunctionCallIndex = value;
+ }
+
+ /// Gets or sets the total number of function call requests within the iteration.
+ ///
+ /// The response from the underlying client might include multiple function call requests.
+ /// This count indicates how many there were.
+ ///
+ public int FunctionCount
+ {
+ get => this._invocationContext.FunctionCount;
+ init => this._invocationContext.FunctionCount = value;
}
///
@@ -122,36 +135,40 @@ public int FunctionSequenceIndex
///
public string? ToolCallId
{
- get => this.CallContent.CallId;
+ get => this._invocationContext.CallContent.CallId;
init
{
- this.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(value ?? string.Empty, this.CallContent.Name, this.CallContent.Arguments);
+ this._invocationContext.CallContent = new Microsoft.Extensions.AI.FunctionCallContent(
+ callId: value ?? string.Empty,
+ name: this._invocationContext.CallContent.Name,
+ arguments: this._invocationContext.CallContent.Arguments);
}
}
///
/// The chat message content associated with automatic function invocation.
///
- public ChatMessageContent ChatMessageContent => (this.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!;
+ public ChatMessageContent ChatMessageContent
+ => (this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.ChatMessageContentKey] as ChatMessageContent)!;
///
/// The execution settings associated with the operation.
///
public PromptExecutionSettings? ExecutionSettings
{
- get => this.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings;
+ get => this._invocationContext.Options?.AdditionalProperties?[ChatOptionsExtensions.PromptExecutionSettingsKey] as PromptExecutionSettings;
init
{
- this.Options ??= new();
- this.Options.AdditionalProperties ??= [];
- this.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value;
+ this._invocationContext.Options ??= new();
+ this._invocationContext.Options.AdditionalProperties ??= [];
+ this._invocationContext.Options.AdditionalProperties[ChatOptionsExtensions.PromptExecutionSettingsKey] = value;
}
}
///
/// Gets the associated with automatic function invocation.
///
- public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this.Messages);
+ public ChatHistory ChatHistory => this._chatHistory ??= new ChatMessageHistory(this._invocationContext.Messages);
///
/// Gets the with which this filter is associated.
@@ -163,9 +180,9 @@ public KernelFunction Function
if (this._kernelFunction is null
// If the schemas are different,
// AIFunction reference potentially was modified and the kernel function should be regenerated.
- || !IsSameSchema(this._kernelFunction, this.AIFunction))
+ || !IsSameSchema(this._kernelFunction, this._invocationContext.Function))
{
- this._kernelFunction = this.AIFunction.AsKernelFunction();
+ this._kernelFunction = this._invocationContext.Function.AsKernelFunction();
}
return this._kernelFunction;
@@ -180,7 +197,7 @@ public Kernel Kernel
get
{
Kernel? kernel = null;
- this.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel);
+ this._invocationContext.Options?.AdditionalProperties?.TryGetValue(ChatOptionsExtensions.KernelKey, out kernel);
// To avoid exception from properties, when attempting to retrieve a kernel from a non-ready context, it will give a null.
return kernel!;
@@ -192,6 +209,32 @@ public Kernel Kernel
///
public FunctionResult Result { get; set; }
+ /// Gets or sets a value indicating whether to terminate the request.
+ ///
+ /// In response to a function call request, the function might be invoked, its result added to the chat contents,
+ /// and a new request issued to the wrapped client. If this property is set to , that subsequent request
+ /// will not be issued; instead, the loop is terminated immediately rather than continuing until there are no
+ /// more function call requests in responses.
+ ///
+ public bool Terminate
+ {
+ get => this._invocationContext.Terminate;
+ set => this._invocationContext.Terminate = value;
+ }
+
+ /// Gets or sets the function call content information associated with this invocation.
+ internal Microsoft.Extensions.AI.FunctionCallContent CallContent
+ {
+ get => this._invocationContext.CallContent;
+ set => this._invocationContext.CallContent = value;
+ }
+
+ internal AIFunction AIFunction
+ {
+ get => this._invocationContext.Function;
+ set => this._invocationContext.Function = value;
+ }
+
private static bool IsSameSchema(KernelFunction kernelFunction, AIFunction aiFunction)
{
// Compares the schemas, should be similar.
@@ -305,11 +348,10 @@ IEnumerator IEnumerable.GetEnumerator()
=> ((IEnumerable)this).GetEnumerator();
}
- /// Destructor to clear the overrides and update the message.
+ /// Finalizer to clear the chat history overrides.
~AutoFunctionInvocationContext()
{
- // The moment this class is destroyed, we need to clear the overrides and
- // overrides to update to message
+ // When this instance is finalized, clear the chat history's update-message overrides.
this._chatHistory?.ClearOverrides();
}
}
From c3b89d825b3e6d31271d0a5ce514f1c8feba8966 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 14 Apr 2025 10:53:39 +0100
Subject: [PATCH 19/26] Starting update of FunctionInvokingChatClient
---
dotnet/Directory.Packages.props | 3 +-
.../AI/ChatClient/AIFunctionFactory.cs | 631 -------------
.../AI/ChatClient/FunctionFactoryOptions.cs | 63 --
.../KernelFunctionInvokingChatClient.cs | 4 +-
.../KernelFunctionInvokingChatClientV2.cs | 873 ++++++++++++++++++
.../SemanticKernel.Abstractions.csproj | 2 +-
6 files changed, 877 insertions(+), 699 deletions(-)
delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs
delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs
create mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs
diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props
index 7878b0d5b359..94e6c8de85ef 100644
--- a/dotnet/Directory.Packages.props
+++ b/dotnet/Directory.Packages.props
@@ -93,7 +93,6 @@
-
@@ -109,7 +108,7 @@
-
+
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs
deleted file mode 100644
index e7afdd77e153..000000000000
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/AIFunctionFactory.cs
+++ /dev/null
@@ -1,631 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Buffers;
-using System.Collections.Concurrent;
-using System.Collections.Generic;
-using System.ComponentModel;
-using System.Diagnostics;
-using System.Diagnostics.CodeAnalysis;
-using System.IO;
-using System.Linq;
-using System.Reflection;
-using System.Runtime.CompilerServices;
-using System.Text.Json;
-using System.Text.Json.Nodes;
-using System.Text.Json.Serialization.Metadata;
-using System.Text.RegularExpressions;
-using System.Threading;
-using System.Threading.Tasks;
-using Microsoft.Extensions.AI;
-
-#pragma warning disable IDE0009 // Use explicit 'this.' qualifier
-#pragma warning disable IDE1006 // Missing static prefix s_ suffix
-
-namespace Microsoft.SemanticKernel.ChatCompletion;
-
-// Slight modified source from
-// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs
-
-/// Provides factory methods for creating commonly used implementations of .
-[ExcludeFromCodeCoverage]
-internal static partial class AIFunctionFactory
-{
- /// Holds the default options instance used when creating function.
- private static readonly AIFunctionFactoryOptions _defaultOptions = new();
-
- /// Creates an instance for a method, specified via a delegate.
- /// The method to be represented via the created .
- /// Metadata to use to override defaults inferred from .
- /// The created for invoking .
- ///
- ///
- /// Return values are serialized to using 's
- /// . Arguments that are not already of the expected type are
- /// marshaled to the expected type via JSON and using 's
- /// . If the argument is a ,
- /// , or , it is deserialized directly. If the argument is anything else unknown,
- /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type.
- ///
- ///
- public static AIFunction Create(Delegate method, AIFunctionFactoryOptions? options)
- {
- Verify.NotNull(method);
-
- return ReflectionAIFunction.Build(method.Method, method.Target, options ?? _defaultOptions);
- }
-
- /// Creates an instance for a method, specified via a delegate.
- /// The method to be represented via the created .
- /// The name to use for the .
- /// The description to use for the .
- /// The used to marshal function parameters and any return value.
- /// The created for invoking .
- ///
- ///
- /// Return values are serialized to using .
- /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using
- /// . If the argument is a , ,
- /// or , it is deserialized directly. If the argument is anything else unknown, it is
- /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type.
- ///
- ///
- public static AIFunction Create(Delegate method, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null)
- {
- Verify.NotNull(method);
-
- AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null
- ? _defaultOptions
- : new()
- {
- Name = name,
- Description = description,
- SerializerOptions = serializerOptions,
- };
-
- return ReflectionAIFunction.Build(method.Method, method.Target, createOptions);
- }
-
- ///
- /// Creates an instance for a method, specified via an instance
- /// and an optional target object if the method is an instance method.
- ///
- /// The method to be represented via the created .
- ///
- /// The target object for the if it represents an instance method.
- /// This should be if and only if is a static method.
- ///
- /// Metadata to use to override defaults inferred from .
- /// The created for invoking .
- ///
- ///
- /// Return values are serialized to using 's
- /// . Arguments that are not already of the expected type are
- /// marshaled to the expected type via JSON and using 's
- /// . If the argument is a ,
- /// , or , it is deserialized directly. If the argument is anything else unknown,
- /// it is round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type.
- ///
- ///
- public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryOptions? options)
- {
- Verify.NotNull(method);
-
- return ReflectionAIFunction.Build(method, target, options ?? _defaultOptions);
- }
-
- ///
- /// Creates an instance for a method, specified via an instance
- /// and an optional target object if the method is an instance method.
- ///
- /// The method to be represented via the created .
- ///
- /// The target object for the if it represents an instance method.
- /// This should be if and only if is a static method.
- ///
- /// The name to use for the .
- /// The description to use for the .
- /// The used to marshal function parameters and return value.
- /// The created for invoking .
- ///
- ///
- /// Return values are serialized to using .
- /// Arguments that are not already of the expected type are marshaled to the expected type via JSON and using
- /// . If the argument is a , ,
- /// or , it is deserialized directly. If the argument is anything else unknown, it is
- /// round-tripped through JSON, serializing the object as JSON and then deserializing it to the expected type.
- ///
- ///
- public static AIFunction Create(MethodInfo method, object? target, string? name = null, string? description = null, JsonSerializerOptions? serializerOptions = null)
- {
- Verify.NotNull(method);
-
- AIFunctionFactoryOptions createOptions = serializerOptions is null && name is null && description is null
- ? _defaultOptions
- : new()
- {
- Name = name,
- Description = description,
- SerializerOptions = serializerOptions,
- };
-
- return ReflectionAIFunction.Build(method, target, createOptions);
- }
-
- private sealed class ReflectionAIFunction : AIFunction
- {
- public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFunctionFactoryOptions options)
- {
- Verify.NotNull(method);
-
- if (method.ContainsGenericParameters)
- {
- throw new ArgumentException("Open generic methods are not supported", nameof(method));
- }
-
- if (!method.IsStatic && target is null)
- {
- throw new ArgumentNullException(nameof(target), "Target must not be null for an instance method.");
- }
-
- ReflectionAIFunctionDescriptor functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options);
-
- if (target is null && options.AdditionalProperties is null)
- {
- // We can use a cached value for static methods not specifying additional properties.
- return functionDescriptor.CachedDefaultInstance ??= new(functionDescriptor, target, options);
- }
-
- return new(functionDescriptor, target, options);
- }
-
- private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options)
- {
- FunctionDescriptor = functionDescriptor;
- Target = target;
- AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance;
- }
-
- public ReflectionAIFunctionDescriptor FunctionDescriptor { get; }
- public object? Target { get; }
- public override IReadOnlyDictionary AdditionalProperties { get; }
- public override string Name => FunctionDescriptor.Name;
- public override string Description => FunctionDescriptor.Description;
- public override MethodInfo UnderlyingMethod => FunctionDescriptor.Method;
- public override JsonElement JsonSchema => FunctionDescriptor.JsonSchema;
- public override JsonSerializerOptions JsonSerializerOptions => FunctionDescriptor.JsonSerializerOptions;
- protected override ValueTask InvokeCoreAsync(
- AIFunctionArguments? arguments,
- CancellationToken cancellationToken)
- {
- var paramMarshallers = FunctionDescriptor.ParameterMarshallers;
- object?[] args = paramMarshallers.Length != 0 ? new object?[paramMarshallers.Length] : [];
-
- IReadOnlyDictionary argDict =
- arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance :
- arguments as IReadOnlyDictionary ??
- arguments.
-#if NET8_0_OR_GREATER
- ToDictionary();
-#else
- ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
-#endif
- for (int i = 0; i < args.Length; i++)
- {
- args[i] = paramMarshallers[i](argDict, cancellationToken);
- }
-
- return FunctionDescriptor.ReturnParameterMarshaller(ReflectionInvoke(FunctionDescriptor.Method, Target, args), cancellationToken);
- }
- }
-
- ///
- /// A descriptor for a .NET method-backed AIFunction that precomputes its marshalling delegates and JSON schema.
- ///
- private sealed class ReflectionAIFunctionDescriptor
- {
- private const int InnerCacheSoftLimit = 512;
- private static readonly ConditionalWeakTable> _descriptorCache = new();
-
- /// A boxed .
- private static readonly object? _boxedDefaultCancellationToken = default(CancellationToken);
-
- ///
- /// Gets or creates a descriptors using the specified method and options.
- ///
- public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFunctionFactoryOptions options)
- {
- JsonSerializerOptions serializerOptions = options.SerializerOptions ?? AIJsonUtilities.DefaultOptions;
- AIJsonSchemaCreateOptions schemaOptions = options.JsonSchemaCreateOptions ?? AIJsonSchemaCreateOptions.Default;
- serializerOptions.MakeReadOnly();
- ConcurrentDictionary innerCache = _descriptorCache.GetOrCreateValue(serializerOptions);
-
- DescriptorKey key = new(method, options.Name, options.Description, schemaOptions);
- if (innerCache.TryGetValue(key, out ReflectionAIFunctionDescriptor? descriptor))
- {
- return descriptor;
- }
-
- descriptor = new(key, serializerOptions);
- return innerCache.Count < InnerCacheSoftLimit
- ? innerCache.GetOrAdd(key, descriptor)
- : descriptor;
- }
-
- private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
- {
- // Get marshaling delegates for parameters.
- ParameterInfo[] parameters = key.Method.GetParameters();
- ParameterMarshallers = new Func, CancellationToken, object?>[parameters.Length];
- for (int i = 0; i < parameters.Length; i++)
- {
- ParameterMarshallers[i] = GetParameterMarshaller(serializerOptions, parameters[i]);
- }
-
- // Get a marshaling delegate for the return value.
- ReturnParameterMarshaller = GetReturnParameterMarshaller(key.Method, serializerOptions);
-
- Method = key.Method;
- Name = key.Name ?? GetFunctionName(key.Method);
- Description = key.Description ?? key.Method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty;
- JsonSerializerOptions = serializerOptions;
- JsonSchema = AIJsonUtilities.CreateFunctionJsonSchema(
- key.Method,
- Name,
- Description,
- serializerOptions,
- key.SchemaOptions);
- }
-
- public string Name { get; }
- public string Description { get; }
- public MethodInfo Method { get; }
- public JsonSerializerOptions JsonSerializerOptions { get; }
- public JsonElement JsonSchema { get; }
- public Func, CancellationToken, object?>[] ParameterMarshallers { get; }
- public Func> ReturnParameterMarshaller { get; }
- public ReflectionAIFunction? CachedDefaultInstance { get; set; }
-
- private static string GetFunctionName(MethodInfo method)
- {
- // Get the function name to use.
- string name = SanitizeMemberName(method.Name);
-
- const string AsyncSuffix = "Async";
- if (IsAsyncMethod(method) &&
- name.EndsWith(AsyncSuffix, StringComparison.Ordinal) &&
- name.Length > AsyncSuffix.Length)
- {
- name = name.Substring(0, name.Length - AsyncSuffix.Length);
- }
-
- return name;
-
- static bool IsAsyncMethod(MethodInfo method)
- {
- Type t = method.ReturnType;
-
- if (t == typeof(Task) || t == typeof(ValueTask))
- {
- return true;
- }
-
- if (t.IsGenericType)
- {
- t = t.GetGenericTypeDefinition();
- if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>))
- {
- return true;
- }
- }
-
- return false;
- }
- }
-
- ///
- /// Gets a delegate for handling the marshaling of a parameter.
- ///
- private static Func, CancellationToken, object?> GetParameterMarshaller(
- JsonSerializerOptions serializerOptions,
- ParameterInfo parameter)
- {
- if (string.IsNullOrWhiteSpace(parameter.Name))
- {
- throw new ArgumentException("Parameter is missing a name.", nameof(parameter));
- }
-
- // Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
- Type parameterType = parameter.ParameterType;
- JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
-
- // For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
- if (parameterType == typeof(CancellationToken))
- {
- return static (_, cancellationToken) =>
- cancellationToken == default ? _boxedDefaultCancellationToken : // optimize common case of a default CT to avoid boxing
- cancellationToken;
- }
-
- // For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
- return (arguments, _) =>
- {
- // If the parameter has an argument specified in the dictionary, return that argument.
- if (arguments.TryGetValue(parameter.Name, out object? value))
- {
- return value switch
- {
- null => null, // Return as-is if null -- if the parameter is a struct this will be handled by MethodInfo.Invoke
- _ when parameterType.IsInstanceOfType(value) => value, // Do nothing if value is assignable to parameter type
- JsonElement element => JsonSerializer.Deserialize(element, typeInfo),
- JsonDocument doc => JsonSerializer.Deserialize(doc, typeInfo),
- JsonNode node => JsonSerializer.Deserialize(node, typeInfo),
- _ => MarshallViaJsonRoundtrip(value),
- };
-
- object? MarshallViaJsonRoundtrip(object value)
- {
-#pragma warning disable CA1031 // Do not catch general exception types
- try
- {
- string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
- return JsonSerializer.Deserialize(json, typeInfo);
- }
- catch
- {
- // Eat any exceptions and fall back to the original value to force a cast exception later on.
- return value;
- }
-#pragma warning restore CA1031
- }
- }
-
- // There was no argument for the parameter in the dictionary.
- // Does it have a default value?
- if (parameter.HasDefaultValue)
- {
- return parameter.DefaultValue;
- }
-
- // Leave it empty.
- return null;
- };
- }
-
- ///
- /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation.
- ///
- private static Func> GetReturnParameterMarshaller(MethodInfo method, JsonSerializerOptions serializerOptions)
- {
- Type returnType = method.ReturnType;
- JsonTypeInfo returnTypeInfo;
-
- // Void
- if (returnType == typeof(void))
- {
- return static (_, _) => new ValueTask((object?)null);
- }
-
- // Task
- if (returnType == typeof(Task))
- {
- return async static (result, _) =>
- {
- await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false);
- return null;
- };
- }
-
- // ValueTask
- if (returnType == typeof(ValueTask))
- {
- return async static (result, _) =>
- {
- await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false);
- return null;
- };
- }
-
- if (returnType.IsGenericType)
- {
- // Task
- if (returnType.GetGenericTypeDefinition() == typeof(Task<>))
- {
- MethodInfo taskResultGetter = GetMethodFromGenericMethodDefinition(returnType, _taskGetResult);
- returnTypeInfo = serializerOptions.GetTypeInfo(taskResultGetter.ReturnType);
- return async (taskObj, cancellationToken) =>
- {
- await ((Task)ThrowIfNullResult(taskObj)).ConfigureAwait(false);
- object? result = ReflectionInvoke(taskResultGetter, taskObj, null);
- return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
- };
- }
-
- // ValueTask
- if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
- {
- MethodInfo valueTaskAsTask = GetMethodFromGenericMethodDefinition(returnType, _valueTaskAsTask);
- MethodInfo asTaskResultGetter = GetMethodFromGenericMethodDefinition(valueTaskAsTask.ReturnType, _taskGetResult);
- returnTypeInfo = serializerOptions.GetTypeInfo(asTaskResultGetter.ReturnType);
- return async (taskObj, cancellationToken) =>
- {
- var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(taskObj), null)!;
- await task.ConfigureAwait(false);
- object? result = ReflectionInvoke(asTaskResultGetter, task, null);
- return await SerializeResultAsync(result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
- };
- }
- }
-
- // For everything else, just serialize the result as-is.
- returnTypeInfo = serializerOptions.GetTypeInfo(returnType);
- return (result, cancellationToken) => SerializeResultAsync(result, returnTypeInfo, cancellationToken);
-
- static async ValueTask SerializeResultAsync(object? result, JsonTypeInfo returnTypeInfo, CancellationToken cancellationToken)
- {
- if (returnTypeInfo.Kind is JsonTypeInfoKind.None)
- {
- // Special-case trivial contracts to avoid the more expensive general-purpose serialization path.
- return JsonSerializer.SerializeToElement(result, returnTypeInfo);
- }
-
- // Serialize asynchronously to support potential IAsyncEnumerable responses.
- using PooledMemoryStream stream = new();
-#if NET9_0_OR_GREATER
- await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
- Utf8JsonReader reader = new(stream.GetBuffer());
- return JsonElement.ParseValue(ref reader);
-#else
- await JsonSerializer.SerializeAsync(stream, result, returnTypeInfo, cancellationToken).ConfigureAwait(false);
- stream.Position = 0;
- var serializerOptions = _defaultOptions.SerializerOptions ?? AIJsonUtilities.DefaultOptions;
- return await JsonSerializer.DeserializeAsync(stream, serializerOptions.GetTypeInfo(typeof(JsonElement)), cancellationToken).ConfigureAwait(false);
-#endif
- }
-
- // Throws an exception if a result is found to be null unexpectedly
- static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly.");
- }
-
- private static readonly MethodInfo _taskGetResult = typeof(Task<>).GetProperty(nameof(Task.Result), BindingFlags.Instance | BindingFlags.Public)!.GetMethod!;
- private static readonly MethodInfo _valueTaskAsTask = typeof(ValueTask<>).GetMethod(nameof(ValueTask.AsTask), BindingFlags.Instance | BindingFlags.Public)!;
-
- private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedType, MethodInfo genericMethodDefinition)
- {
- Debug.Assert(specializedType.IsGenericType && specializedType.GetGenericTypeDefinition() == genericMethodDefinition.DeclaringType, "generic member definition doesn't match type.");
-#if NET
- return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition);
-#else
-#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
- const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
-#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
- return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
-#endif
- }
-
- private record struct DescriptorKey(MethodInfo Method, string? Name, string? Description, AIJsonSchemaCreateOptions SchemaOptions);
- }
-
- ///
- /// Removes characters from a .NET member name that shouldn't be used in an AI function name.
- ///
- /// The .NET member name that should be sanitized.
- ///
- /// Replaces non-alphanumeric characters in the identifier with the underscore character.
- /// Primarily intended to remove characters produced by compiler-generated method name mangling.
- ///
- internal static string SanitizeMemberName(string memberName)
- {
- Verify.NotNull(memberName);
- return InvalidNameCharsRegex().Replace(memberName, "_");
- }
-
- /// Regex that flags any character other than ASCII digits or letters or the underscore.
-#if NET
- [GeneratedRegex("[^0-9A-Za-z_]")]
- private static partial Regex InvalidNameCharsRegex();
-#else
- private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex;
- private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled);
-#endif
-
- /// Invokes the MethodInfo with the specified target object and arguments.
- private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments)
- {
-#if NET
- return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null);
-#else
- try
- {
- return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null);
- }
- catch (TargetInvocationException e) when (e.InnerException is not null)
- {
- // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions
- // is ignored, the original exception will be wrapped in a TargetInvocationException.
- // Unwrap it and throw that original exception, maintaining its stack information.
- System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw();
- throw;
- }
-#endif
- }
-
- ///
- /// Implements a simple write-only memory stream that uses pooled buffers.
- ///
- private sealed class PooledMemoryStream : Stream
- {
- private const int DefaultBufferSize = 4096;
- private byte[] _buffer;
- private int _position;
-
- public PooledMemoryStream(int initialCapacity = DefaultBufferSize)
- {
- _buffer = ArrayPool.Shared.Rent(initialCapacity);
- _position = 0;
- }
-
- public ReadOnlySpan GetBuffer() => _buffer.AsSpan(0, _position);
- public override bool CanWrite => true;
- public override bool CanRead => false;
- public override bool CanSeek => false;
- public override long Length => _position;
- public override long Position
- {
- get => _position;
- set => throw new NotSupportedException();
- }
-
- public override void Write(byte[] buffer, int offset, int count)
- {
- EnsureNotDisposed();
- EnsureCapacity(_position + count);
-
- Buffer.BlockCopy(buffer, offset, _buffer, _position, count);
- _position += count;
- }
-
- public override void Flush()
- {
- }
-
- public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
- public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
- public override void SetLength(long value) => throw new NotSupportedException();
-
- protected override void Dispose(bool disposing)
- {
- if (_buffer is not null)
- {
- ArrayPool.Shared.Return(_buffer);
- _buffer = null!;
- }
-
- base.Dispose(disposing);
- }
-
- private void EnsureCapacity(int requiredCapacity)
- {
- if (requiredCapacity <= _buffer.Length)
- {
- return;
- }
-
- int newCapacity = Math.Max(requiredCapacity, _buffer.Length * 2);
- byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity);
- Buffer.BlockCopy(_buffer, 0, newBuffer, 0, _position);
-
- ArrayPool.Shared.Return(_buffer);
- _buffer = newBuffer;
- }
-
- private void EnsureNotDisposed()
- {
- if (_buffer is null)
- {
- Throw();
- static void Throw() => throw new ObjectDisposedException(nameof(PooledMemoryStream));
- }
- }
- }
-}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs
deleted file mode 100644
index f9f43ee630ae..000000000000
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/FunctionFactoryOptions.cs
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Collections.Generic;
-using System.ComponentModel;
-using System.Diagnostics.CodeAnalysis;
-using System.Reflection;
-using System.Text.Json;
-using Microsoft.Extensions.AI;
-
-namespace Microsoft.SemanticKernel.ChatCompletion;
-
-// Slight modified source from
-// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs
-
-///
-/// Represents options that can be provided when creating an from a method.
-///
-[ExcludeFromCodeCoverage]
-internal sealed class AIFunctionFactoryOptions
-{
- ///
- /// Initializes a new instance of the class.
- ///
- public AIFunctionFactoryOptions()
- {
- }
-
- /// Gets or sets the used to marshal .NET values being passed to the underlying delegate.
- ///
- /// If no value has been specified, the instance will be used.
- ///
- public JsonSerializerOptions? SerializerOptions { get; set; }
-
- ///
- /// Gets or sets the governing the generation of JSON schemas for the function.
- ///
- ///
- /// If no value has been specified, the instance will be used.
- ///
- public AIJsonSchemaCreateOptions? JsonSchemaCreateOptions { get; set; }
-
- /// Gets or sets the name to use for the function.
- ///
- /// The name to use for the function. The default value is a name derived from the method represented by the passed or .
- ///
- public string? Name { get; set; }
-
- /// Gets or sets the description to use for the function.
- ///
- /// The description for the function. The default value is a description derived from the passed or , if possible
- /// (for example, via a on the method).
- ///
- public string? Description { get; set; }
-
- ///
- /// Gets or sets additional values to store on the resulting property.
- ///
- ///
- /// This property can be used to provide arbitrary information about the function.
- ///
- public IReadOnlyDictionary? AdditionalProperties { get; set; }
-}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 3a5eb88b5cf3..df7b9f107fa2 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -45,7 +45,7 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
///
///
[ExcludeFromCodeCoverage]
-internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatClient
+internal sealed partial class KernelFunctionInvokingChatClientOld : DelegatingChatClient
{
/// The for the current function invocation.
private static readonly AsyncLocal s_currentContext = new();
@@ -65,7 +65,7 @@ internal sealed partial class KernelFunctionInvokingChatClient : DelegatingChatC
///
/// The underlying , or the next instance in a chain of clients.
/// An to use for logging information about function invocation.
- public KernelFunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
+ public KernelFunctionInvokingChatClientOld(IChatClient innerClient, ILogger? logger = null)
: base(innerClient)
{
this._logger = logger ?? NullLogger.Instance;
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs
new file mode 100644
index 000000000000..6978a01dd44b
--- /dev/null
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs
@@ -0,0 +1,873 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Shared.Diagnostics;
+
+#pragma warning disable CA2213 // Disposable fields should be disposed
+#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
+#pragma warning disable SA1202 // 'protected' members should come before 'private' members
+
+namespace Microsoft.Extensions.AI;
+
+///
+/// A delegating chat client that invokes functions defined on .
+/// Include this in a chat pipeline to resolve function calls automatically.
+///
+///
+///
+/// When this client receives a in a chat response, it responds
+/// by calling the corresponding defined in ,
+/// producing a that it sends back to the inner client. This loop
+/// is repeated until there are no more function calls to make, or until another stop condition is met,
+/// such as hitting .
+///
+///
+/// The provided implementation of is thread-safe for concurrent use so long as the
+/// instances employed as part of the supplied are also safe.
+/// The property can be used to control whether multiple function invocation
+/// requests as part of the same request are invocable concurrently, but even with that set to
+/// (the default), multiple concurrent requests to this same instance and using the same tools could result in those
+/// tools being used concurrently (one per request). For example, a function that accesses the HttpContext of a specific
+/// ASP.NET web request should only be used as part of a single at a time, and only with
+/// set to , in case the inner client decided to issue multiple
+/// invocation requests to that same function.
+///
+///
+public partial class FunctionInvokingChatClient : DelegatingChatClient
+{
+ /// The for the current function invocation.
+ private static readonly AsyncLocal _currentContext = new();
+
+ /// Optional services used for function invocation.
+ private readonly IServiceProvider? _functionInvocationServices;
+
+ /// The logger to use for logging information about function invocation.
+ private readonly ILogger _logger;
+
+ /// The to use for telemetry.
+ /// This component does not own the instance and should not dispose it.
+ private readonly ActivitySource? _activitySource;
+
+ /// Maximum number of roundtrips allowed to the inner client.
+ private int _maximumIterationsPerRequest = 10;
+
+ /// Maximum number of consecutive iterations that are allowed contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing.
+ private int _maximumConsecutiveErrorsPerRequest = 3;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The underlying , or the next instance in a chain of clients.
+ /// An to use for logging information about function invocation.
+ /// An optional to use for resolving services required by the instances being invoked.
+ public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null)
+ : base(innerClient)
+ {
+ _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
+ _activitySource = innerClient.GetService();
+ _functionInvocationServices = functionInvocationServices;
+ }
+
+ ///
+ /// Gets or sets the for the current function invocation.
+ ///
+ ///
+ /// This value flows across async calls.
+ ///
+ public static FunctionInvocationContext? CurrentContext
+ {
+ get => _currentContext.Value;
+ protected set => _currentContext.Value = value;
+ }
+
+ ///
+ /// Gets or sets a value indicating whether detailed exception information should be included
+ /// in the chat history when calling the underlying .
+ ///
+ ///
+ /// if the full exception message is added to the chat history
+ /// when calling the underlying .
+ /// if a generic error message is included in the chat history.
+ /// The default value is .
+ ///
+ ///
+ ///
+ /// Setting the value to prevents the underlying language model from disclosing
+ /// raw exception details to the end user, since it doesn't receive that information. Even in this
+ /// case, the raw object is available to application code by inspecting
+ /// the property.
+ ///
+ ///
+ /// Setting the value to can help the underlying bypass problems on
+ /// its own, for example by retrying the function call with different arguments. However it might
+ /// result in disclosing the raw exception information to external users, which can be a security
+ /// concern depending on the application scenario.
+ ///
+ ///
+ /// Changing the value of this property while the client is in use might result in inconsistencies
+ /// as to whether detailed errors are provided during an in-flight request.
+ ///
+ ///
+ public bool IncludeDetailedErrors { get; set; }
+
+ ///
+ /// Gets or sets a value indicating whether to allow concurrent invocation of functions.
+ ///
+ ///
+ /// if multiple function calls can execute in parallel.
+ /// if function calls are processed serially.
+ /// The default value is .
+ ///
+ ///
+ /// An individual response from the inner client might contain multiple function call requests.
+ /// By default, such function calls are processed serially. Set to
+ /// to enable concurrent invocation such that multiple function calls can execute in parallel.
+ ///
+ public bool AllowConcurrentInvocation { get; set; }
+
+ ///
+ /// Gets or sets the maximum number of iterations per request.
+ ///
+ ///
+ /// The maximum number of iterations per request.
+ /// The default value is 10.
+ ///
+ ///
+ ///
+ /// Each request to this might end up making
+ /// multiple requests to the inner client. Each time the inner client responds with
+ /// a function call request, this client might perform that invocation and send the results
+ /// back to the inner client in a new request. This property limits the number of times
+ /// such a roundtrip is performed. The value must be at least one, as it includes the initial request.
+ ///
+ ///
+ /// Changing the value of this property while the client is in use might result in inconsistencies
+ /// as to how many iterations are allowed for an in-flight request.
+ ///
+ ///
+ public int MaximumIterationsPerRequest
+ {
+ get => _maximumIterationsPerRequest;
+ set
+ {
+ if (value < 1)
+ {
+ Throw.ArgumentOutOfRangeException(nameof(value));
+ }
+
+ _maximumIterationsPerRequest = value;
+ }
+ }
+
+ ///
+ /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error.
+ ///
+ ///
+ /// The maximum number of consecutive iterations that are allowed to fail with an error.
+ /// The default value is 3.
+ ///
+ ///
+ ///
+ /// When function invocations fail with an exception, the
+ /// continues to make requests to the inner client, optionally supplying exception information (as
+ /// controlled by ). This allows the to
+ /// recover from errors by trying other function parameters that may succeed.
+ ///
+ ///
+ /// However, in case function invocations continue to produce exceptions, this property can be used to
+ /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be
+ /// rethrown to the caller.
+ ///
+ ///
+ /// If the value is set to zero, all function calling exceptions immediately terminate the function
+ /// invocation loop and the exception will be rethrown to the caller.
+ ///
+ ///
+ /// Changing the value of this property while the client is in use might result in inconsistencies
+ /// as to how many iterations are allowed for an in-flight request.
+ ///
+ ///
+ public int MaximumConsecutiveErrorsPerRequest
+ {
+ get => _maximumConsecutiveErrorsPerRequest;
+ set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0);
+ }
+
+ ///
+ public override async Task GetResponseAsync(
+ IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+ {
+ _ = Throw.IfNull(messages);
+
+ // A single request into this GetResponseAsync may result in multiple requests to the inner client.
+ // Create an activity to group them together for better observability.
+ using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
+
+ // Copy the original messages in order to avoid enumerating the original messages multiple times.
+ // The IEnumerable can represent an arbitrary amount of work.
+ List originalMessages = [.. messages];
+ messages = originalMessages;
+
+ List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
+ ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned
+ List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response
+ UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response
+ List? functionCallContents = null; // function call contents that need responding to in the current turn
+ bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set
+ int consecutiveErrorCount = 0;
+
+ for (int iteration = 0; ; iteration++)
+ {
+ functionCallContents?.Clear();
+
+ // Make the call to the inner client.
+ response = await base.GetResponseAsync(messages, options, cancellationToken);
+ if (response is null)
+ {
+ Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}.");
+ }
+
+ // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
+ bool requiresFunctionInvocation =
+ options?.Tools is { Count: > 0 } &&
+ iteration < MaximumIterationsPerRequest &&
+ CopyFunctionCalls(response.Messages, ref functionCallContents);
+
+ // In a common case where we make a request and there's no function calling work required,
+ // fast path out by just returning the original response.
+ if (iteration == 0 && !requiresFunctionInvocation)
+ {
+ return response;
+ }
+
+ // Track aggregatable details from the response, including all of the response messages and usage details.
+ (responseMessages ??= []).AddRange(response.Messages);
+ if (response.Usage is not null)
+ {
+ if (totalUsage is not null)
+ {
+ totalUsage.Add(response.Usage);
+ }
+ else
+ {
+ totalUsage = response.Usage;
+ }
+ }
+
+ // If there are no tools to call, or for any other reason we should stop, we're done.
+ // Break out of the loop and allow the handling at the end to configure the response
+ // with aggregated data from previous requests.
+ if (!requiresFunctionInvocation)
+ {
+ break;
+ }
+
+ // Prepare the history for the next iteration.
+ FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
+
+ // Add the responses from the function calls into the augmented history and also into the tracked
+ // list of response messages.
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken);
+ responseMessages.AddRange(modeAndMessages.MessagesAdded);
+ consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
+
+ if (modeAndMessages.ShouldTerminate)
+ {
+ break;
+ }
+
+ UpdateOptionsForNextIteration(ref options!, response.ChatThreadId);
+ }
+
+ Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
+ response.Messages = responseMessages!;
+ response.Usage = totalUsage;
+
+ return response;
+ }
+
+ ///
+ public override async IAsyncEnumerable GetStreamingResponseAsync(
+ IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ _ = Throw.IfNull(messages);
+
+ // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client.
+ // Create an activity to group them together for better observability.
+ using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
+
+ // Copy the original messages in order to avoid enumerating the original messages multiple times.
+ // The IEnumerable can represent an arbitrary amount of work.
+ List originalMessages = [.. messages];
+ messages = originalMessages;
+
+ List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
+ List? functionCallContents = null; // function call contents that need responding to in the current turn
+ List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history
+ bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set
+ List updates = []; // updates from the current response
+ int consecutiveErrorCount = 0;
+
+ for (int iteration = 0; ; iteration++)
+ {
+ updates.Clear();
+ functionCallContents?.Clear();
+
+ await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken))
+ {
+ if (update is null)
+ {
+ Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}.");
+ }
+
+ updates.Add(update);
+
+ _ = CopyFunctionCalls(update.Contents, ref functionCallContents);
+
+ yield return update;
+ Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
+ }
+
+ // If there are no tools to call, or for any other reason we should stop, return the response.
+ if (functionCallContents is not { Count: > 0 } ||
+ options?.Tools is not { Count: > 0 } ||
+ iteration >= _maximumIterationsPerRequest)
+ {
+ break;
+ }
+
+ // Reconsistitue a response from the response updates.
+ var response = updates.ToChatResponse();
+ (responseMessages ??= []).AddRange(response.Messages);
+
+ // Prepare the history for the next iteration.
+ FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
+
+ // Process all of the functions, adding their results into the history.
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken);
+ responseMessages.AddRange(modeAndMessages.MessagesAdded);
+ consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
+
+ // This is a synthetic ID since we're generating the tool messages instead of getting them from
+ // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to
+ // use the same message ID for all of them within a given iteration, as this is a single logical
+ // message with multiple content items. We could also use different message IDs per tool content,
+ // but there's no benefit to doing so.
+ string toolResponseId = Guid.NewGuid().ToString("N");
+
+ // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
+ // includes all activitys, including generated function results.
+ foreach (var message in modeAndMessages.MessagesAdded)
+ {
+ var toolResultUpdate = new ChatResponseUpdate
+ {
+ AdditionalProperties = message.AdditionalProperties,
+ AuthorName = message.AuthorName,
+ ChatThreadId = response.ChatThreadId,
+ CreatedAt = DateTimeOffset.UtcNow,
+ Contents = message.Contents,
+ RawRepresentation = message.RawRepresentation,
+ ResponseId = toolResponseId,
+ MessageId = toolResponseId, // See above for why this can be the same as ResponseId
+ Role = message.Role,
+ };
+
+ yield return toolResultUpdate;
+ Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
+ }
+
+ if (modeAndMessages.ShouldTerminate)
+ {
+ yield break;
+ }
+
+ UpdateOptionsForNextIteration(ref options, response.ChatThreadId);
+ }
+ }
+
+ /// Prepares the various chat message lists after a response from the inner client and before invoking functions.
+ /// The original messages provided by the caller.
+ /// The messages reference passed to the inner client.
+ /// The augmented history containing all the messages to be sent.
+ /// The most recent response being handled.
+ /// A list of all response messages received up until this point.
+ /// Whether the previous iteration's response had a thread id.
+ private static void FixupHistories(
+ IEnumerable originalMessages,
+ ref IEnumerable messages,
+ [NotNull] ref List? augmentedHistory,
+ ChatResponse response,
+ List allTurnsResponseMessages,
+ ref bool lastIterationHadThreadId)
+ {
+ // We're now going to need to augment the history with function result contents.
+ // That means we need a separate list to store the augmented history.
+ if (response.ChatThreadId is not null)
+ {
+ // The response indicates the inner client is tracking the history, so we don't want to send
+ // anything we've already sent or received.
+ if (augmentedHistory is not null)
+ {
+ augmentedHistory.Clear();
+ }
+ else
+ {
+ augmentedHistory = [];
+ }
+
+ lastIterationHadThreadId = true;
+ }
+ else if (lastIterationHadThreadId)
+ {
+ // In the very rare case where the inner client returned a response with a thread ID but then
+ // returned a subsequent response without one, we want to reconstitue the full history. To do that,
+ // we can populate the history with the original chat messages and then all of the response
+ // messages up until this point, which includes the most recent ones.
+ augmentedHistory ??= [];
+ augmentedHistory.Clear();
+ augmentedHistory.AddRange(originalMessages);
+ augmentedHistory.AddRange(allTurnsResponseMessages);
+
+ lastIterationHadThreadId = false;
+ }
+ else
+ {
+ // If augmentedHistory is already non-null, then we've already populated it with everything up
+ // until this point (except for the most recent response). If it's null, we need to seed it with
+ // the chat history provided by the caller.
+ augmentedHistory ??= originalMessages.ToList();
+
+ // Now add the most recent response messages.
+ augmentedHistory.AddMessages(response);
+
+ lastIterationHadThreadId = false;
+ }
+
+ // Use the augmented history as the new set of messages to send.
+ messages = augmentedHistory;
+ }
+
+ /// Copies any from to .
+ private static bool CopyFunctionCalls(
+ IList messages, [NotNullWhen(true)] ref List? functionCalls)
+ {
+ bool any = false;
+ int count = messages.Count;
+ for (int i = 0; i < count; i++)
+ {
+ any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls);
+ }
+
+ return any;
+ }
+
+ /// Copies any from to .
+ private static bool CopyFunctionCalls(
+ IList content, [NotNullWhen(true)] ref List? functionCalls)
+ {
+ bool any = false;
+ int count = content.Count;
+ for (int i = 0; i < count; i++)
+ {
+ if (content[i] is FunctionCallContent functionCall)
+ {
+ (functionCalls ??= []).Add(functionCall);
+ any = true;
+ }
+ }
+
+ return any;
+ }
+
+ private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId)
+ {
+ if (options.ToolMode is RequiredChatToolMode)
+ {
+ // We have to reset the tool mode to be non-required after the first iteration,
+ // as otherwise we'll be in an infinite loop.
+ options = options.Clone();
+ options.ToolMode = null;
+ options.ChatThreadId = chatThreadId;
+ }
+ else if (options.ChatThreadId != chatThreadId)
+ {
+ // As with the other modes, ensure we've propagated the chat thread ID to the options.
+ // We only need to clone the options if we're actually mutating it.
+ options = options.Clone();
+ options.ChatThreadId = chatThreadId;
+ }
+ }
+
+ ///
+ /// Processes the function calls in the list.
+ ///
+ /// The current chat contents, inclusive of the function call contents being processed.
+ /// The options used for the response being processed.
+ /// The function call contents representing the functions to be invoked.
+ /// The iteration number of how many roundtrips have been made to the inner client.
+ /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors.
+ /// The to monitor for cancellation requests.
+ /// A value indicating how the caller should proceed.
+ private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync(
+ List messages, ChatOptions options, List functionCallContents, int iteration, int consecutiveErrorCount, CancellationToken cancellationToken)
+ {
+ // We must add a response for every tool call, regardless of whether we successfully executed it or not.
+ // If we successfully execute it, we'll add the result. If we don't, we'll add an error.
+
+ Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call.");
+
+ var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest;
+
+ // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel.
+ if (functionCallContents.Count == 1)
+ {
+ FunctionInvocationResult result = await ProcessFunctionCallAsync(
+ messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken);
+
+ IList added = CreateResponseMessages([result]);
+ ThrowIfNoFunctionResultsAdded(added);
+ UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount);
+
+ messages.AddRange(added);
+ return (result.ShouldTerminate, consecutiveErrorCount, added);
+ }
+ else
+ {
+ FunctionInvocationResult[] results;
+
+ if (AllowConcurrentInvocation)
+ {
+ // Rather than await'ing each function before invoking the next, invoke all of them
+ // and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
+ // but if a function invocation completes asynchronously, its processing can overlap
+ // with the processing of other the other invocation invocations.
+ results = await Task.WhenAll(
+ from i in Enumerable.Range(0, functionCallContents.Count)
+ select ProcessFunctionCallAsync(
+ messages, options, functionCallContents,
+ iteration, i, captureExceptions: true, cancellationToken));
+ }
+ else
+ {
+ // Invoke each function serially.
+ results = new FunctionInvocationResult[functionCallContents.Count];
+ for (int i = 0; i < results.Length; i++)
+ {
+ results[i] = await ProcessFunctionCallAsync(
+ messages, options, functionCallContents,
+ iteration, i, captureCurrentIterationExceptions, cancellationToken);
+ }
+ }
+
+ var shouldTerminate = false;
+
+ IList added = CreateResponseMessages(results);
+ ThrowIfNoFunctionResultsAdded(added);
+ UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount);
+
+ messages.AddRange(added);
+ foreach (FunctionInvocationResult fir in results)
+ {
+ shouldTerminate = shouldTerminate || fir.ShouldTerminate;
+ }
+
+ return (shouldTerminate, consecutiveErrorCount, added);
+ }
+ }
+
+ private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount)
+ {
+ var allExceptions = added.SelectMany(m => m.Contents.OfType())
+ .Select(frc => frc.Exception!)
+ .Where(e => e is not null);
+
+ if (allExceptions.Any())
+ {
+ consecutiveErrorCount++;
+ if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest)
+ {
+ var allExceptionsArray = allExceptions.ToArray();
+ if (allExceptionsArray.Length == 1)
+ {
+ ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw();
+ }
+ else
+ {
+ throw new AggregateException(allExceptionsArray);
+ }
+ }
+ }
+ else
+ {
+ consecutiveErrorCount = 0;
+ }
+ }
+
+ ///
+ /// Throws an exception if doesn't create any messages.
+ ///
+ private void ThrowIfNoFunctionResultsAdded(IList? messages)
+ {
+ if (messages is null || messages.Count == 0)
+ {
+ Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages.");
+ }
+ }
+
+ /// Processes the function call described in [].
+ /// The current chat contents, inclusive of the function call contents being processed.
+ /// The options used for the response being processed.
+ /// The function call contents representing all the functions being invoked.
+ /// The iteration number of how many roundtrips have been made to the inner client.
+ /// The 0-based index of the function being called out of .
+ /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows.
+ /// The to monitor for cancellation requests.
+ /// A value indicating how the caller should proceed.
+ private async Task ProcessFunctionCallAsync(
+ List messages, ChatOptions options, List callContents,
+ int iteration, int functionCallIndex, bool captureExceptions, CancellationToken cancellationToken)
+ {
+ var callContent = callContents[functionCallIndex];
+
+ // Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
+ AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name);
+ if (function is null)
+ {
+ return new(shouldTerminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
+ }
+
+ FunctionInvocationContext context = new()
+ {
+ Function = function,
+ Arguments = new(callContent.Arguments) { Services = _functionInvocationServices },
+
+ Messages = messages,
+ Options = options,
+
+ CallContent = callContent,
+ Iteration = iteration,
+ FunctionCallIndex = functionCallIndex,
+ FunctionCount = callContents.Count,
+ };
+
+ object? result;
+ try
+ {
+ result = await InvokeFunctionAsync(context, cancellationToken);
+ }
+ catch (Exception e) when (!cancellationToken.IsCancellationRequested)
+ {
+ if (!captureExceptions)
+ {
+ throw;
+ }
+
+ return new(
+ shouldTerminate: false,
+ FunctionInvocationStatus.Exception,
+ callContent,
+ result: null,
+ exception: e);
+ }
+
+ return new(
+ shouldTerminate: context.Terminate,
+ FunctionInvocationStatus.RanToCompletion,
+ callContent,
+ result,
+ exception: null);
+ }
+
+ /// Creates one or more response messages for function invocation results.
+ /// Information about the function call invocations and results.
+ /// A list of all chat messages created from .
+ protected virtual IList CreateResponseMessages(
+ ReadOnlySpan results)
+ {
+ var contents = new List(results.Length);
+ for (int i = 0; i < results.Length; i++)
+ {
+ contents.Add(CreateFunctionResultContent(results[i]));
+ }
+
+ return [new(ChatRole.Tool, contents)];
+
+ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result)
+ {
+ _ = Throw.IfNull(result);
+
+ object? functionResult;
+ if (result.Status == FunctionInvocationStatus.RanToCompletion)
+ {
+ functionResult = result.Result ?? "Success: Function completed.";
+ }
+ else
+ {
+ string message = result.Status switch
+ {
+ FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.",
+ FunctionInvocationStatus.Exception => "Error: Function failed.",
+ _ => "Error: Unknown error.",
+ };
+
+ if (IncludeDetailedErrors && result.Exception is not null)
+ {
+ message = $"{message} Exception: {result.Exception.Message}";
+ }
+
+ functionResult = message;
+ }
+
+ return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception };
+ }
+ }
+
+ /// Invokes the function asynchronously.
+ ///
+ /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
+ ///
+ /// The to monitor for cancellation requests. The default is .
+ /// The result of the function invocation, or if the function invocation returned .
+ /// is .
+ protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken)
+ {
+ _ = Throw.IfNull(context);
+
+ using Activity? activity = _activitySource?.StartActivity(context.Function.Name);
+
+ long startingTimestamp = 0;
+ if (_logger.IsEnabled(LogLevel.Debug))
+ {
+ startingTimestamp = Stopwatch.GetTimestamp();
+ if (_logger.IsEnabled(LogLevel.Trace))
+ {
+ LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions));
+ }
+ else
+ {
+ LogInvoking(context.Function.Name);
+ }
+ }
+
+ object? result = null;
+ try
+ {
+ CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
+ result = await context.Function.InvokeAsync(context.Arguments, cancellationToken);
+ }
+ catch (Exception e)
+ {
+ if (activity is not null)
+ {
+ _ = activity.SetTag("error.type", e.GetType().FullName)
+ .SetStatus(ActivityStatusCode.Error, e.Message);
+ }
+
+ if (e is OperationCanceledException)
+ {
+ LogInvocationCanceled(context.Function.Name);
+ }
+ else
+ {
+ LogInvocationFailed(context.Function.Name, e);
+ }
+
+ throw;
+ }
+ finally
+ {
+ if (_logger.IsEnabled(LogLevel.Debug))
+ {
+ TimeSpan elapsed = GetElapsedTime(startingTimestamp);
+
+ if (result is not null && _logger.IsEnabled(LogLevel.Trace))
+ {
+ LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.JsonSerializerOptions));
+ }
+ else
+ {
+ LogInvocationCompleted(context.Function.Name, elapsed);
+ }
+ }
+ }
+
+ return result;
+ }
+
+ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
+#if NET
+ Stopwatch.GetElapsedTime(startingTimestamp);
+#else
+ new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency)));
+#endif
+
+ [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)]
+ private partial void LogInvoking(string methodName);
+
+ [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)]
+ private partial void LogInvokingSensitive(string methodName, string arguments);
+
+ [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)]
+ private partial void LogInvocationCompleted(string methodName, TimeSpan duration);
+
+ [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)]
+ private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result);
+
+ [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")]
+ private partial void LogInvocationCanceled(string methodName);
+
+ [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")]
+ private partial void LogInvocationFailed(string methodName, Exception error);
+
+ /// Provides information about the invocation of a function call.
+ public sealed class FunctionInvocationResult
+ {
+ internal FunctionInvocationResult(bool shouldTerminate, FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception)
+ {
+ ShouldTerminate = shouldTerminate;
+ Status = status;
+ CallContent = callContent;
+ Result = result;
+ Exception = exception;
+ }
+
+ /// Gets status about how the function invocation completed.
+ public FunctionInvocationStatus Status { get; }
+
+ /// Gets the function call content information associated with this invocation.
+ public FunctionCallContent CallContent { get; }
+
+ /// Gets the result of the function call.
+ public object? Result { get; }
+
+ /// Gets any exception the function call threw.
+ public Exception? Exception { get; }
+
+ /// Gets a value indicating whether the caller should terminate the processing loop.
+ internal bool ShouldTerminate { get; }
+ }
+
+ /// Provides error codes for when errors occur as part of the function calling loop.
+ public enum FunctionInvocationStatus
+ {
+ /// The operation completed successfully.
+ RanToCompletion,
+
+ /// The requested function could not be found.
+ NotFound,
+
+ /// The function call failed with an exception.
+ Exception,
+ }
+}
diff --git a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj
index 47043cbe1df8..4826426418a6 100644
--- a/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj
+++ b/dotnet/src/SemanticKernel.Abstractions/SemanticKernel.Abstractions.csproj
@@ -29,7 +29,7 @@
-
+
From 3d56b9564712731846bc5949aa8408ccfae2c190 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 14 Apr 2025 12:20:14 +0100
Subject: [PATCH 20/26] Using Extensions AI + Latest logic for
FunctionINvokingChatClient
---
...IServiceCollectionExtensions.ChatClient.cs | 12 +-
.../AI/ChatClient/ChatClientExtensions.cs | 6 +-
.../KernelFunctionInvocationContext.cs | 79 +-
.../KernelFunctionInvokingChatClient.cs | 496 +++++-----
.../KernelFunctionInvokingChatClientV2.cs | 873 ------------------
5 files changed, 331 insertions(+), 1135 deletions(-)
delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
index 7fa002d806ea..2954e958936a 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIServiceCollectionExtensions.ChatClient.cs
@@ -47,12 +47,12 @@ public static IServiceCollection AddOpenAIChatClient(
IChatClient Factory(IServiceProvider serviceProvider, object? _)
{
- ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+ var loggerFactory = serviceProvider.GetService();
return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), options: GetClientOptions(orgId: orgId, httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider)))
.GetChatClient(modelId)
.AsIChatClient()
- .AsKernelFunctionInvokingChatClient(logger);
+ .AsKernelFunctionInvokingChatClient(loggerFactory);
}
services.AddKeyedSingleton(serviceId, (Func)Factory);
@@ -77,12 +77,12 @@ public static IServiceCollection AddOpenAIChatClient(this IServiceCollection ser
IChatClient Factory(IServiceProvider serviceProvider, object? _)
{
- ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+ var loggerFactory = serviceProvider.GetService();
return (openAIClient ?? serviceProvider.GetRequiredService())
.GetChatClient(modelId)
.AsIChatClient()
- .AsKernelFunctionInvokingChatClient(logger);
+ .AsKernelFunctionInvokingChatClient(loggerFactory);
}
services.AddKeyedSingleton(serviceId, (Func)Factory);
@@ -114,12 +114,12 @@ public static IServiceCollection AddOpenAIChatClient(
IChatClient Factory(IServiceProvider serviceProvider, object? _)
{
- ILogger? logger = serviceProvider.GetService()?.CreateLogger();
+ var loggerFactory = serviceProvider.GetService();
return new OpenAIClient(new ApiKeyCredential(apiKey ?? SingleSpace), GetClientOptions(endpoint, orgId, HttpClientProvider.GetHttpClient(httpClient, serviceProvider)))
.GetChatClient(modelId)
.AsIChatClient()
- .AsKernelFunctionInvokingChatClient(logger);
+ .AsKernelFunctionInvokingChatClient(loggerFactory);
}
services.AddKeyedSingleton(serviceId, (Func)Factory);
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
index 0ffdd5fec99d..223530beb1ba 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
@@ -92,16 +92,16 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl
/// Creates a new that supports for function invocation with a .
///
/// Target chat client service.
- /// Optional logger to use for logging.
+ /// Optional logger to use for logging.
/// Function invoking chat client.
[Experimental("SKEXP0001")]
- public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILogger? logger = null)
+ public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILoggerFactory? loggerFactory = null)
{
Verify.NotNull(client);
return client is KernelFunctionInvokingChatClient kernelFunctionInvocationClient
? kernelFunctionInvocationClient
- : new KernelFunctionInvokingChatClient(client, logger);
+ : new KernelFunctionInvokingChatClient(client, loggerFactory);
}
private static ChatOptions GetChatOptionsFromSettings(PromptExecutionSettings? executionSettings, Kernel? kernel)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
index c8f293b2bb0c..efc3d39ab48a 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
#pragma warning disable IDE0009 // Use explicit 'this.' qualifier
@@ -20,63 +21,61 @@ internal class KernelFunctionInvocationContext
{
///
/// A nop function used to allow to be non-nullable. Default instances of
- /// start with this as the target function.
+ /// start with this as the target function.
///
- private static readonly AIFunction s_nopFunction = AIFunctionFactory.Create(() => { }, nameof(KernelFunctionInvocationContext));
+ private static readonly AIFunction _nopFunction = AIFunctionFactory.Create(() => { }, nameof(FunctionInvocationContext));
/// The chat contents associated with the operation that initiated this function call request.
private IList _messages = Array.Empty();
/// The AI function to be invoked.
- private AIFunction _function = s_nopFunction;
+ private AIFunction _function = _nopFunction;
/// The function call content information associated with this invocation.
- private Microsoft.Extensions.AI.FunctionCallContent _callContent = new(string.Empty, s_nopFunction.Name, EmptyReadOnlyDictionary.Instance);
+ private Microsoft.Extensions.AI.FunctionCallContent? _callContent;
- /// Initializes a new instance of the class.
- internal KernelFunctionInvocationContext()
+ /// The arguments used with the function.
+ private AIFunctionArguments? _arguments;
+
+ /// Initializes a new instance of the class.
+ public KernelFunctionInvocationContext()
+ {
+ }
+
+ /// Gets or sets the AI function to be invoked.
+ public AIFunction Function
+ {
+ get => _function;
+ set => _function = Throw.IfNull(value);
+ }
+
+ /// Gets or sets the arguments associated with this invocation.
+ public AIFunctionArguments Arguments
{
+ get => _arguments ??= [];
+ set => _arguments = Throw.IfNull(value);
}
/// Gets or sets the function call content information associated with this invocation.
public Microsoft.Extensions.AI.FunctionCallContent CallContent
{
- get => _callContent;
- set
- {
- Verify.NotNull(value);
- _callContent = value;
- }
+ get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary.Instance);
+ set => _callContent = Throw.IfNull(value);
}
/// Gets or sets the chat contents associated with the operation that initiated this function call request.
public IList Messages
{
get => _messages;
- set
- {
- Verify.NotNull(value);
- _messages = value;
- }
+ set => _messages = Throw.IfNull(value);
}
/// Gets or sets the chat options associated with the operation that initiated this function call request.
public ChatOptions? Options { get; set; }
- /// Gets or sets the AI function to be invoked.
- public AIFunction Function
- {
- get => _function;
- set
- {
- Verify.NotNull(value);
- _function = value;
- }
- }
-
/// Gets or sets the number of this iteration with the underlying client.
///
- /// The initial request to the client that passes along the chat contents provided to the
+ /// The initial request to the client that passes along the chat contents provided to the
/// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on.
///
public int Iteration { get; set; }
@@ -103,4 +102,26 @@ public AIFunction Function
/// more function call requests in responses.
///
public bool Terminate { get; set; }
+
+ private static class Throw
+ {
+ ///
+ /// Throws an if the specified argument is .
+ ///
+ /// Argument type to be checked for .
+ /// Object to be checked for .
+ /// The name of the parameter being checked.
+ /// The original value of .
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [return: NotNull]
+ public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "")
+ {
+ if (argument is null)
+ {
+ throw new ArgumentNullException(paramName);
+ }
+
+ return argument;
+ }
+ }
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index df7b9f107fa2..e6496cc63f4f 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -1,26 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
+#pragma warning restore IDE0073 // The file header does not match the required text
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
+using System.Runtime.ExceptionServices;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
-using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.ChatCompletion;
+#pragma warning disable IDE1006 // Naming Styles
+#pragma warning disable IDE0009 // This
#pragma warning disable CA2213 // Disposable fields should be disposed
-#pragma warning disable IDE0009 // Use explicit 'this.' qualifier
-#pragma warning disable IDE1006 // Missing prefix: 's_'
+#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
+#pragma warning disable SA1202 // 'protected' members should come before 'private' members
-namespace Microsoft.SemanticKernel.ChatCompletion;
+// Modified source from 2025-04-07
+// https://raw.githubusercontent.com/dotnet/extensions/84d09b794d994435568adcbb85a981143d4f15cb/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
-// Slight modified source from
-// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs
+namespace Microsoft.Extensions.AI;
///
/// A delegating chat client that invokes functions defined on .
@@ -28,9 +33,11 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
///
///
///
-/// When this client receives a in a chat response, it responds
-/// by calling the corresponding defined in ,
-/// producing a .
+/// When this client receives a in a chat response, it responds
+/// by calling the corresponding defined in ,
+/// producing a that it sends back to the inner client. This loop
+/// is repeated until there are no more function calls to make, or until another stop condition is met,
+/// such as hitting .
///
///
/// The provided implementation of is thread-safe for concurrent use so long as the
@@ -44,11 +51,13 @@ namespace Microsoft.SemanticKernel.ChatCompletion;
/// invocation requests to that same function.
///
///
-[ExcludeFromCodeCoverage]
-internal sealed partial class KernelFunctionInvokingChatClientOld : DelegatingChatClient
+public partial class KernelFunctionInvokingChatClient : DelegatingChatClient
{
/// The for the current function invocation.
- private static readonly AsyncLocal s_currentContext = new();
+ private static readonly AsyncLocal _currentContext = new();
+
+ /// Optional services used for function invocation.
+ private readonly IServiceProvider? _functionInvocationServices;
/// The logger to use for logging information about function invocation.
private readonly ILogger _logger;
@@ -58,49 +67,37 @@ internal sealed partial class KernelFunctionInvokingChatClientOld : DelegatingCh
private readonly ActivitySource? _activitySource;
/// Maximum number of roundtrips allowed to the inner client.
- private int? _maximumIterationsPerRequest;
+ private int _maximumIterationsPerRequest = 10;
+
+ /// Maximum number of consecutive iterations that are allowed to contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing.
+ private int _maximumConsecutiveErrorsPerRequest = 3;
///
/// Initializes a new instance of the class.
///
/// The underlying , or the next instance in a chain of clients.
- /// An to use for logging information about function invocation.
- public KernelFunctionInvokingChatClientOld(IChatClient innerClient, ILogger? logger = null)
+ /// An to use for logging information about function invocation.
+ /// An optional to use for resolving services required by the instances being invoked.
+ public KernelFunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null)
: base(innerClient)
{
- this._logger = logger ?? NullLogger.Instance;
- this._activitySource = innerClient.GetService();
+ _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
+ _activitySource = innerClient.GetService();
+ _functionInvocationServices = functionInvocationServices;
}
///
- /// Gets or sets the for the current function invocation.
+ /// Gets or sets the for the current function invocation.
///
///
/// This value flows across async calls.
///
- internal static AutoFunctionInvocationContext? CurrentContext
+ public static AutoFunctionInvocationContext? CurrentContext
{
- get => s_currentContext.Value;
- set => s_currentContext.Value = value;
+ get => _currentContext.Value;
+ protected set => _currentContext.Value = value;
}
- ///
- /// Gets or sets a value indicating whether to handle exceptions that occur during function calls.
- ///
- ///
- /// if the
- /// underlying will be instructed to give a response without invoking
- /// any further functions if a function call fails with an exception.
- /// if the underlying is allowed
- /// to continue attempting function calls until is reached.
- /// The default value is .
- ///
- ///
- /// Changing the value of this property while the client is in use might result in inconsistencies
- /// whether errors are retried during an in-flight request.
- ///
- public bool RetryOnError { get; set; }
-
///
/// Gets or sets a value indicating whether detailed exception information should be included
/// in the chat history when calling the underlying .
@@ -116,17 +113,17 @@ internal static AutoFunctionInvocationContext? CurrentContext
/// Setting the value to prevents the underlying language model from disclosing
/// raw exception details to the end user, since it doesn't receive that information. Even in this
/// case, the raw object is available to application code by inspecting
- /// the property.
+ /// the property.
///
///
/// Setting the value to can help the underlying bypass problems on
- /// its own, for example by retrying the function call with different arguments. However, it might
+ /// its own, for example by retrying the function call with different arguments. However it might
/// result in disclosing the raw exception information to external users, which can be a security
/// concern depending on the application scenario.
///
///
/// Changing the value of this property while the client is in use might result in inconsistencies
- /// whether detailed errors are provided during an in-flight request.
+ /// as to whether detailed errors are provided during an in-flight request.
///
///
public bool IncludeDetailedErrors { get; set; }
@@ -151,25 +148,24 @@ internal static AutoFunctionInvocationContext? CurrentContext
///
///
/// The maximum number of iterations per request.
- /// The default value is .
+ /// The default value is 10.
///
///
///
- /// Each request to this might end up making
+ /// Each request to this might end up making
/// multiple requests to the inner client. Each time the inner client responds with
/// a function call request, this client might perform that invocation and send the results
/// back to the inner client in a new request. This property limits the number of times
- /// such a roundtrip is performed. If null, there is no limit applied. If set, the value
- /// must be at least one, as it includes the initial request.
+ /// such a roundtrip is performed. The value must be at least one, as it includes the initial request.
///
///
/// Changing the value of this property while the client is in use might result in inconsistencies
/// as to how many iterations are allowed for an in-flight request.
///
///
- public int? MaximumIterationsPerRequest
+ public int MaximumIterationsPerRequest
{
- get => this._maximumIterationsPerRequest;
+ get => _maximumIterationsPerRequest;
set
{
if (value < 1)
@@ -177,7 +173,48 @@ public int? MaximumIterationsPerRequest
throw new ArgumentOutOfRangeException(nameof(value));
}
- this._maximumIterationsPerRequest = value;
+ _maximumIterationsPerRequest = value;
+ }
+ }
+
+ ///
+ /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error.
+ ///
+ ///
+ /// The maximum number of consecutive iterations that are allowed to fail with an error.
+ /// The default value is 3.
+ ///
+ ///
+ ///
+ /// When function invocations fail with an exception, the
+ /// continues to make requests to the inner client, optionally supplying exception information (as
+ /// controlled by ). This allows the to
+ /// recover from errors by trying other function parameters that may succeed.
+ ///
+ ///
+ /// However, in case function invocations continue to produce exceptions, this property can be used to
+ /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be
+ /// rethrown to the caller.
+ ///
+ ///
+ /// If the value is set to zero, all function calling exceptions immediately terminate the function
+ /// invocation loop and the exception will be rethrown to the caller.
+ ///
+ ///
+ /// Changing the value of this property while the client is in use might result in inconsistencies
+ /// as to how many iterations are allowed for an in-flight request.
+ ///
+ ///
+ public int MaximumConsecutiveErrorsPerRequest
+ {
+ get => _maximumConsecutiveErrorsPerRequest;
+ set
+ {
+ if (value < 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(value), "Argument less than minimum value 0");
+ }
+ _maximumConsecutiveErrorsPerRequest = value;
}
}
@@ -189,7 +226,7 @@ public override async Task GetResponseAsync(
// A single request into this GetResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
- using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
+ using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
// Copy the original messages in order to avoid enumerating the original messages multiple times.
// The IEnumerable can represent an arbitrary amount of work.
@@ -200,8 +237,9 @@ public override async Task GetResponseAsync(
ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned
List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response
UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response
- List? functionCallContents = null; // function call contents that need responding to in the current turn
+ List? functionCallContents = null; // function call contents that need responding to in the current turn
bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set
+ int consecutiveErrorCount = 0;
for (int iteration = 0; ; iteration++)
{
@@ -217,7 +255,7 @@ public override async Task GetResponseAsync(
// Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
bool requiresFunctionInvocation =
options?.Tools is { Count: > 0 } &&
- (!this.MaximumIterationsPerRequest.HasValue || iteration < this.MaximumIterationsPerRequest.GetValueOrDefault()) &&
+ iteration < MaximumIterationsPerRequest &&
CopyFunctionCalls(response.Messages, ref functionCallContents);
// In a common case where we make a request and there's no function calling work required,
@@ -227,7 +265,7 @@ public override async Task GetResponseAsync(
return response;
}
- // Track aggregate details from the response, including all the response messages and usage details.
+ // Track aggregatable details from the response, including all of the response messages and usage details.
(responseMessages ??= []).AddRange(response.Messages);
if (response.Usage is not null)
{
@@ -257,21 +295,23 @@ public override async Task GetResponseAsync(
// Add the responses from the function calls into the augmented history and also into the tracked
// list of response messages.
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents!, iteration, isStreaming: false, cancellationToken).ConfigureAwait(false);
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, isStreaming: false, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
+ consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
- // Clear the auto function invocation options.
- ClearOptionsForAutoFunctionInvocation(ref options);
-
- if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId))
+ if (modeAndMessages.ShouldTerminate)
{
- // Terminate
break;
}
+
+ // Clear the auto function invocation options.
+ ClearOptionsForAutoFunctionInvocation(ref options);
+
+ UpdateOptionsForNextIteration(ref options!, response.ChatThreadId);
}
Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
- response.Messages = responseMessages;
+ response.Messages = responseMessages!;
response.Usage = totalUsage;
return response;
@@ -285,7 +325,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
// A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
- using Activity? activity = this._activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
+ using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
// Copy the original messages in order to avoid enumerating the original messages multiple times.
// The IEnumerable can represent an arbitrary amount of work.
@@ -293,10 +333,11 @@ public override async IAsyncEnumerable GetStreamingResponseA
messages = originalMessages;
List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
- List? functionCallContents = null; // function call contents that need responding to in the current turn
+ List? functionCallContents = null; // function call contents that need responding to in the current turn
List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history
bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set
List updates = []; // updates from the current response
+ int consecutiveErrorCount = 0;
for (int iteration = 0; ; iteration++)
{
@@ -321,12 +362,12 @@ public override async IAsyncEnumerable GetStreamingResponseA
// If there are no tools to call, or for any other reason we should stop, return the response.
if (functionCallContents is not { Count: > 0 } ||
options?.Tools is not { Count: > 0 } ||
- (this.MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations))
+ iteration >= _maximumIterationsPerRequest)
{
break;
}
- // Reconstitute a response from the response updates.
+ // Reconstitute a response from the response updates.
var response = updates.ToChatResponse();
(responseMessages ??= []).AddRange(response.Messages);
@@ -336,16 +377,23 @@ public override async IAsyncEnumerable GetStreamingResponseA
// Prepare the options for the next auto function invocation iteration.
UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true);
- // Process all the functions, adding their results into the history.
- var modeAndMessages = await this.ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, isStreaming: true, cancellationToken).ConfigureAwait(false);
+ // Process all of the functions, adding their results into the history.
+ var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
+ consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
// Clear the auto function invocation options.
ClearOptionsForAutoFunctionInvocation(ref options);
- // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
- // include all activities, including generated function results.
+ // This is a synthetic ID since we're generating the tool messages instead of getting them from
+ // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to
+ // use the same message ID for all of them within a given iteration, as this is a single logical
+ // message with multiple content items. We could also use different message IDs per tool content,
+ // but there's no benefit to doing so.
string toolResponseId = Guid.NewGuid().ToString("N");
+
+ // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
+ // include all activities, including generated function results.
foreach (var message in modeAndMessages.MessagesAdded)
{
var toolResultUpdate = new ChatResponseUpdate
@@ -357,6 +405,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
Contents = message.Contents,
RawRepresentation = message.RawRepresentation,
ResponseId = toolResponseId,
+ MessageId = toolResponseId, // See above for why this can be the same as ResponseId
Role = message.Role,
};
@@ -364,11 +413,12 @@ public override async IAsyncEnumerable GetStreamingResponseA
Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
}
- if (UpdateOptionsForMode(modeAndMessages.Mode, ref options, response.ChatThreadId))
+ if (modeAndMessages.ShouldTerminate)
{
- // Terminate
yield break;
}
+
+ UpdateOptionsForNextIteration(ref options, response.ChatThreadId);
}
}
@@ -407,8 +457,8 @@ private static void FixupHistories(
else if (lastIterationHadThreadId)
{
// In the very rare case where the inner client returned a response with a thread ID but then
- // returned a subsequent response without one, we want to reconstitute the full history. To do that,
- // we can populate the history with the original chat messages and then all the response
+ // returned a subsequent response without one, we want to reconstitute the full history. To do that,
+ // we can populate the history with the original chat messages and then all of the response
// messages up until this point, which includes the most recent ones.
augmentedHistory ??= [];
augmentedHistory.Clear();
@@ -436,7 +486,7 @@ private static void FixupHistories(
/// Copies any from to .
private static bool CopyFunctionCalls(
- IList messages, [NotNullWhen(true)] ref List? functionCalls)
+ IList messages, [NotNullWhen(true)] ref List? functionCalls)
{
bool any = false;
int count = messages.Count;
@@ -448,15 +498,15 @@ private static bool CopyFunctionCalls(
return any;
}
- /// Copies any from to .
+ /// Copies any from to .
private static bool CopyFunctionCalls(
- IList content, [NotNullWhen(true)] ref List? functionCalls)
+ IList content, [NotNullWhen(true)] ref List? functionCalls)
{
bool any = false;
int count = content.Count;
for (int i = 0; i < count; i++)
{
- if (content[i] is Microsoft.Extensions.AI.FunctionCallContent functionCall)
+ if (content[i] is FunctionCallContent functionCall)
{
(functionCalls ??= []).Add(functionCall);
any = true;
@@ -497,47 +547,23 @@ private static void ClearOptionsForAutoFunctionInvocation(ref ChatOptions option
}
}
- /// Updates for the response.
- /// true if the function calling loop should terminate; otherwise, false.
- private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions options, string? chatThreadId)
+ private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId)
{
- switch (mode)
+ if (options.ToolMode is RequiredChatToolMode)
{
- case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode:
- // We have to reset the tool mode to be non-required after the first iteration,
- // as otherwise we'll be in an infinite loop.
- options = options.Clone();
- options.ToolMode = null;
- options.ChatThreadId = chatThreadId;
-
- break;
-
- case ContinueMode.AllowOneMoreRoundtrip:
- // The LLM gets one further chance to answer, but cannot use tools.
- options = options.Clone();
- options.Tools = null;
- options.ToolMode = null;
- options.ChatThreadId = chatThreadId;
-
- break;
-
- case ContinueMode.Terminate:
- // Bail immediately.
- return true;
-
- default:
- // As with the other modes, ensure we've propagated the chat thread ID to the options.
- // We only need to clone the options if we're actually mutating it.
- if (options.ChatThreadId != chatThreadId)
- {
- options = options.Clone();
- options.ChatThreadId = chatThreadId;
- }
-
- break;
+ // We have to reset the tool mode to be non-required after the first iteration,
+ // as otherwise we'll be in an infinite loop.
+ options = options.Clone();
+ options.ToolMode = null;
+ options.ChatThreadId = chatThreadId;
+ }
+ else if (options.ChatThreadId != chatThreadId)
+ {
+ // As with the other modes, ensure we've propagated the chat thread ID to the options.
+ // We only need to clone the options if we're actually mutating it.
+ options = options.Clone();
+ options.ChatThreadId = chatThreadId;
}
-
- return false;
}
///
@@ -547,85 +573,122 @@ private static bool UpdateOptionsForMode(ContinueMode mode, ref ChatOptions opti
/// The options used for the response being processed.
/// The function call contents representing the functions to be invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
+ /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors.
/// Whether the function calls are being processed in a streaming context.
/// The to monitor for cancellation requests.
- /// A value indicating how the caller should proceed.
- private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync(
- List messages, ChatOptions options, List functionCallContents, int iteration, bool isStreaming, CancellationToken cancellationToken)
+ /// A value indicating how the caller should proceed.
+ private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync(
+ List messages, ChatOptions options, List functionCallContents,
+ int iteration, int consecutiveErrorCount, bool isStreaming, CancellationToken cancellationToken)
{
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
- Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call.");
- ContinueMode continueMode = ContinueMode.Continue;
+ Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call.");
+ var shouldTerminate = false;
+
+ var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest;
// Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel.
if (functionCallContents.Count == 1)
{
- FunctionInvocationResult result = await this.ProcessFunctionCallAsync(
- messages, options, functionCallContents, iteration, 0, isStreaming, cancellationToken).ConfigureAwait(false);
+ FunctionInvocationResult result = await ProcessFunctionCallAsync(
+ messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false);
- IList added = this.CreateResponseMessages([result]);
- this.ThrowIfNoFunctionResultsAdded(added);
+ IList added = CreateResponseMessages([result]);
+ ThrowIfNoFunctionResultsAdded(added);
+ UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount);
messages.AddRange(added);
- return (result.ContinueMode, added);
+ return (result.ShouldTerminate, consecutiveErrorCount, added);
}
else
{
List results = [];
var terminationRequested = false;
- if (this.AllowConcurrentInvocation)
+ if (AllowConcurrentInvocation)
{
- // Schedule the invocation of every function.
+ // Rather than await'ing each function before invoking the next, invoke all of them
+ // and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
+ // but if a function invocation completes asynchronously, its processing can overlap
+ // with the processing of the other invocations.
results.AddRange(await Task.WhenAll(
from i in Enumerable.Range(0, functionCallContents.Count)
- select Task.Run(() => this.ProcessFunctionCallAsync(
+ select ProcessFunctionCallAsync(
messages, options, functionCallContents,
- iteration, i, isStreaming, cancellationToken))).ConfigureAwait(false));
+ iteration, i, captureExceptions: true, isStreaming, cancellationToken)).ConfigureAwait(false));
- terminationRequested = results.Any(r => r.ContinueMode == ContinueMode.Terminate);
+ terminationRequested = results.Any(r => r.ShouldTerminate);
}
else
{
// Invoke each function serially.
for (int i = 0; i < functionCallContents.Count; i++)
{
- var result = await this.ProcessFunctionCallAsync(
+ var result = await ProcessFunctionCallAsync(
messages, options, functionCallContents,
- iteration, i, isStreaming, cancellationToken).ConfigureAwait(false);
+ iteration, i, captureCurrentIterationExceptions, isStreaming, cancellationToken).ConfigureAwait(false);
results.Add(result);
- if (result.ContinueMode == ContinueMode.Terminate)
+ if (result.ShouldTerminate)
{
- continueMode = ContinueMode.Terminate;
+ shouldTerminate = true;
terminationRequested = true;
break;
}
}
}
- IList added = this.CreateResponseMessages(results);
- this.ThrowIfNoFunctionResultsAdded(added);
+ IList added = CreateResponseMessages(results);
+ ThrowIfNoFunctionResultsAdded(added);
+ UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount);
+
messages.AddRange(added);
if (!terminationRequested)
{
// If any function requested termination, we'll terminate.
- continueMode = ContinueMode.Continue;
+ shouldTerminate = false;
foreach (FunctionInvocationResult fir in results)
{
- if (fir.ContinueMode > continueMode)
- {
- continueMode = fir.ContinueMode;
- }
+ shouldTerminate = shouldTerminate || fir.ShouldTerminate;
}
}
- return (continueMode, added);
+ return (shouldTerminate, consecutiveErrorCount, added);
+ }
+ }
+
+ private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount)
+ {
+ var allExceptions = added.SelectMany(m => m.Contents.OfType())
+ .Select(frc => frc.Exception!)
+ .Where(e => e is not null);
+
+#pragma warning disable CA1851 // Possible multiple enumerations of 'IEnumerable' collection
+ if (allExceptions.Any())
+ {
+ consecutiveErrorCount++;
+ if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest)
+ {
+ var allExceptionsArray = allExceptions.ToArray();
+ if (allExceptionsArray.Length == 1)
+ {
+ ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw();
+ }
+ else
+ {
+ throw new AggregateException(allExceptionsArray);
+ }
+ }
}
+ else
+ {
+ consecutiveErrorCount = 0;
+ }
+#pragma warning restore CA1851 // Possible multiple enumerations of 'IEnumerable' collection
}
///
@@ -645,21 +708,21 @@ private void ThrowIfNoFunctionResultsAdded(IList? messages)
/// The function call contents representing all the functions being invoked.
/// The iteration number of how many roundtrips have been made to the inner client.
/// The 0-based index of the function being called out of .
+ /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows.
/// Whether the function calls are being processed in a streaming context.
/// The to monitor for cancellation requests.
- /// A value indicating how the caller should proceed.
+ /// A value indicating how the caller should proceed.
private async Task ProcessFunctionCallAsync(
- List messages, ChatOptions options, List callContents,
- int iteration, int functionCallIndex, bool isStreaming, CancellationToken cancellationToken)
+ List messages, ChatOptions options, List callContents,
+ int iteration, int functionCallIndex, bool captureExceptions, bool isStreaming, CancellationToken cancellationToken)
{
var callContent = callContents[functionCallIndex];
// Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
- AIFunction? function = options.Tools!.OfType().FirstOrDefault(
- t => t.Name == callContent.Name);
+ AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name);
if (function is null)
{
- return new(ContinueMode.Continue, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
+ return new(shouldTerminate: false, FunctionInvokingChatClient.FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
}
if (callContent.Arguments is not null)
@@ -667,73 +730,69 @@ private async Task ProcessFunctionCallAsync(
callContent.Arguments = new KernelArguments(callContent.Arguments);
}
- var context = new AutoFunctionInvocationContext(new KernelFunctionInvocationContext
+ var context = new AutoFunctionInvocationContext(new()
{
- Options = options,
+ Function = function,
+ Arguments = new(callContent.Arguments) { Services = _functionInvocationServices },
+
Messages = messages,
+ Options = options,
+
CallContent = callContent,
- Function = function,
Iteration = iteration,
FunctionCallIndex = functionCallIndex,
FunctionCount = callContents.Count,
- }) { IsStreaming = isStreaming };
+ })
+ { IsStreaming = isStreaming };
object? result;
try
{
- result = await this.InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false);
+ result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false);
}
catch (Exception e) when (!cancellationToken.IsCancellationRequested)
{
- return new(this.RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer.
- FunctionInvocationStatus.Exception,
+ if (!captureExceptions)
+ {
+ throw;
+ }
+
+ return new(
+ shouldTerminate: false,
+ FunctionInvokingChatClient.FunctionInvocationStatus.Exception,
callContent,
result: null,
exception: e);
}
return new(
- context.Terminate ? ContinueMode.Terminate : ContinueMode.Continue,
- FunctionInvocationStatus.RanToCompletion,
+ shouldTerminate: context.Terminate,
+ FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion,
callContent,
result,
exception: null);
}
- /// Represents the return value of , dictating how the loop should behave.
- /// These values are ordered from least severe to most severe, and code explicitly depends on the ordering.
- internal enum ContinueMode
- {
- /// Send back the responses and continue processing.
- Continue = 0,
-
- /// Send back the response but without any tools.
- AllowOneMoreRoundtrip = 1,
-
- /// Immediately exit the function calling loop.
- Terminate = 2,
- }
-
/// Creates one or more response messages for function invocation results.
/// Information about the function call invocations and results.
/// A list of all chat messages created from .
- internal IList CreateResponseMessages(
+ protected virtual IList CreateResponseMessages(
IReadOnlyList results)
{
var contents = new List(results.Count);
- foreach (var t in results)
+ for (int i = 0; i < results.Count; i++)
{
- contents.Add(CreateFunctionResultContent(t));
+ contents.Add(CreateFunctionResultContent(results[i]));
}
return [new(ChatRole.Tool, contents)];
- Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result)
+ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result)
{
Verify.NotNull(result);
object? functionResult;
- if (result.Status == FunctionInvocationStatus.RanToCompletion)
+ if (result.Status == FunctionInvokingChatClient.FunctionInvocationStatus.RanToCompletion)
{
functionResult = result.Result ?? "Success: Function completed.";
}
@@ -741,12 +800,12 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
{
string message = result.Status switch
{
- FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.",
- FunctionInvocationStatus.Exception => "Error: Function failed.",
+ FunctionInvokingChatClient.FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.",
+ FunctionInvokingChatClient.FunctionInvocationStatus.Exception => "Error: Function failed.",
_ => "Error: Unknown error.",
};
- if (this.IncludeDetailedErrors && result.Exception is not null)
+ if (IncludeDetailedErrors && result.Exception is not null)
{
message = $"{message} Exception: {result.Exception.Message}";
}
@@ -754,7 +813,7 @@ Microsoft.Extensions.AI.FunctionResultContent CreateFunctionResultContent(Functi
functionResult = message;
}
- return new Microsoft.Extensions.AI.FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception };
+ return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception };
}
}
@@ -801,40 +860,42 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
}
/// Invokes the function asynchronously.
- /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
+ ///
+ /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
+ ///
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
- /// is .
- private async Task InvokeFunctionAsync(AutoFunctionInvocationContext invocationContext, CancellationToken cancellationToken)
+ /// is .
+ protected virtual async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken)
{
- Verify.NotNull(invocationContext);
+ Verify.NotNull(context);
- using Activity? activity = this._activitySource?.StartActivity(invocationContext.Function.Name);
+ using Activity? activity = _activitySource?.StartActivity(context.Function.Name);
long startingTimestamp = 0;
- if (this._logger.IsEnabled(LogLevel.Debug))
+ if (_logger.IsEnabled(LogLevel.Debug))
{
startingTimestamp = Stopwatch.GetTimestamp();
- if (this._logger.IsEnabled(LogLevel.Trace))
+ if (_logger.IsEnabled(LogLevel.Trace))
{
- this.LogInvokingSensitive(invocationContext.Function.Name, LoggingAsJson(invocationContext.CallContent.Arguments, invocationContext.AIFunction.JsonSerializerOptions));
+ LogInvokingSensitive(context.Function.Name, LoggingAsJson(context.CallContent.Arguments, context.AIFunction.JsonSerializerOptions));
}
else
{
- this.LogInvoking(invocationContext.Function.Name);
+ LogInvoking(context.Function.Name);
}
}
object? result = null;
try
{
- CurrentContext = invocationContext;
- invocationContext = await this.OnAutoFunctionInvocationAsync(
- invocationContext,
- async (context) =>
+ CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
+ context = await this.OnAutoFunctionInvocationAsync(
+ context,
+ async (ctx) =>
{
// Check if filter requested termination
- if (context.Terminate)
+ if (ctx.Terminate)
{
return;
}
@@ -842,10 +903,10 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
// Note that we explicitly do not use executionSettings here; those pertain to the all-up operation and not necessarily to any
// further calls made as part of this function invocation. In particular, we must not use function calling settings naively here,
// as the called function could in turn telling the model about itself as a possible candidate for invocation.
- result = await invocationContext.AIFunction.InvokeAsync(new(invocationContext.Arguments), cancellationToken).ConfigureAwait(false);
- context.Result = new FunctionResult(context.Function, result);
+ result = await context.AIFunction.InvokeAsync(new(context.Arguments), cancellationToken).ConfigureAwait(false);
+ ctx.Result = new FunctionResult(ctx.Function, result);
}).ConfigureAwait(false);
- result = invocationContext.Result.GetValue();
+ result = context.Result.GetValue();
}
catch (Exception e)
{
@@ -857,28 +918,28 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
if (e is OperationCanceledException)
{
- this.LogInvocationCanceled(invocationContext.Function.Name);
+ LogInvocationCanceled(context.Function.Name);
}
else
{
- this.LogInvocationFailed(invocationContext.Function.Name, e);
+ LogInvocationFailed(context.Function.Name, e);
}
throw;
}
finally
{
- if (this._logger.IsEnabled(LogLevel.Debug))
+ if (_logger.IsEnabled(LogLevel.Debug))
{
TimeSpan elapsed = GetElapsedTime(startingTimestamp);
- if (result is not null && this._logger.IsEnabled(LogLevel.Trace))
+ if (result is not null && _logger.IsEnabled(LogLevel.Trace))
{
- this.LogInvocationCompletedSensitive(invocationContext.Function.Name, elapsed, LoggingAsJson(result, invocationContext.AIFunction.JsonSerializerOptions));
+ LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingAsJson(result, context.AIFunction.JsonSerializerOptions));
}
else
{
- this.LogInvocationCompleted(invocationContext.Function.Name, elapsed);
+ LogInvocationCompleted(context.Function.Name, elapsed);
}
}
}
@@ -936,20 +997,20 @@ private static TimeSpan GetElapsedTime(long startingTimestamp) =>
/// Provides information about the invocation of a function call.
public sealed class FunctionInvocationResult
{
- internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationStatus status, Microsoft.Extensions.AI.FunctionCallContent callContent, object? result, Exception? exception)
+ internal FunctionInvocationResult(bool shouldTerminate, FunctionInvokingChatClient.FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception)
{
- this.ContinueMode = continueMode;
- this.Status = status;
- this.CallContent = callContent;
- this.Result = result;
- this.Exception = exception;
+ ShouldTerminate = shouldTerminate;
+ Status = status;
+ CallContent = callContent;
+ Result = result;
+ Exception = exception;
}
/// Gets status about how the function invocation completed.
- public FunctionInvocationStatus Status { get; }
+ public FunctionInvokingChatClient.FunctionInvocationStatus Status { get; }
/// Gets the function call content information associated with this invocation.
- public Microsoft.Extensions.AI.FunctionCallContent CallContent { get; }
+ public FunctionCallContent CallContent { get; }
/// Gets the result of the function call.
public object? Result { get; }
@@ -957,20 +1018,7 @@ internal FunctionInvocationResult(ContinueMode continueMode, FunctionInvocationS
/// Gets any exception the function call threw.
public Exception? Exception { get; }
- /// Gets an indication for how the caller should continue the processing loop.
- internal ContinueMode ContinueMode { get; }
- }
-
- /// Provides error codes for when errors occur as part of the function calling loop.
- public enum FunctionInvocationStatus
- {
- /// The operation completed successfully.
- RanToCompletion,
-
- /// The requested function could not be found.
- NotFound,
-
- /// The function call failed with an exception.
- Exception,
+ /// Gets a value indicating whether the caller should terminate the processing loop.
+ internal bool ShouldTerminate { get; }
}
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs
deleted file mode 100644
index 6978a01dd44b..000000000000
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClientV2.cs
+++ /dev/null
@@ -1,873 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.Diagnostics.CodeAnalysis;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Runtime.ExceptionServices;
-using System.Threading;
-using System.Threading.Tasks;
-using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Abstractions;
-using Microsoft.Shared.Diagnostics;
-
-#pragma warning disable CA2213 // Disposable fields should be disposed
-#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
-#pragma warning disable SA1202 // 'protected' members should come before 'private' members
-
-namespace Microsoft.Extensions.AI;
-
-///
-/// A delegating chat client that invokes functions defined on .
-/// Include this in a chat pipeline to resolve function calls automatically.
-///
-///
-///
-/// When this client receives a in a chat response, it responds
-/// by calling the corresponding defined in ,
-/// producing a that it sends back to the inner client. This loop
-/// is repeated until there are no more function calls to make, or until another stop condition is met,
-/// such as hitting .
-///
-///
-/// The provided implementation of is thread-safe for concurrent use so long as the
-/// instances employed as part of the supplied are also safe.
-/// The property can be used to control whether multiple function invocation
-/// requests as part of the same request are invocable concurrently, but even with that set to
-/// (the default), multiple concurrent requests to this same instance and using the same tools could result in those
-/// tools being used concurrently (one per request). For example, a function that accesses the HttpContext of a specific
-/// ASP.NET web request should only be used as part of a single at a time, and only with
-/// set to , in case the inner client decided to issue multiple
-/// invocation requests to that same function.
-///
-///
-public partial class FunctionInvokingChatClient : DelegatingChatClient
-{
- /// The for the current function invocation.
- private static readonly AsyncLocal _currentContext = new();
-
- /// Optional services used for function invocation.
- private readonly IServiceProvider? _functionInvocationServices;
-
- /// The logger to use for logging information about function invocation.
- private readonly ILogger _logger;
-
- /// The to use for telemetry.
- /// This component does not own the instance and should not dispose it.
- private readonly ActivitySource? _activitySource;
-
- /// Maximum number of roundtrips allowed to the inner client.
- private int _maximumIterationsPerRequest = 10;
-
- /// Maximum number of consecutive iterations that are allowed contain at least one exception result. If the limit is exceeded, we rethrow the exception instead of continuing.
- private int _maximumConsecutiveErrorsPerRequest = 3;
-
- ///
- /// Initializes a new instance of the class.
- ///
- /// The underlying , or the next instance in a chain of clients.
- /// An to use for logging information about function invocation.
- /// An optional to use for resolving services required by the instances being invoked.
- public FunctionInvokingChatClient(IChatClient innerClient, ILoggerFactory? loggerFactory = null, IServiceProvider? functionInvocationServices = null)
- : base(innerClient)
- {
- _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance;
- _activitySource = innerClient.GetService();
- _functionInvocationServices = functionInvocationServices;
- }
-
- ///
- /// Gets or sets the for the current function invocation.
- ///
- ///
- /// This value flows across async calls.
- ///
- public static FunctionInvocationContext? CurrentContext
- {
- get => _currentContext.Value;
- protected set => _currentContext.Value = value;
- }
-
- ///
- /// Gets or sets a value indicating whether detailed exception information should be included
- /// in the chat history when calling the underlying .
- ///
- ///
- /// if the full exception message is added to the chat history
- /// when calling the underlying .
- /// if a generic error message is included in the chat history.
- /// The default value is .
- ///
- ///
- ///
- /// Setting the value to prevents the underlying language model from disclosing
- /// raw exception details to the end user, since it doesn't receive that information. Even in this
- /// case, the raw object is available to application code by inspecting
- /// the property.
- ///
- ///
- /// Setting the value to can help the underlying bypass problems on
- /// its own, for example by retrying the function call with different arguments. However it might
- /// result in disclosing the raw exception information to external users, which can be a security
- /// concern depending on the application scenario.
- ///
- ///
- /// Changing the value of this property while the client is in use might result in inconsistencies
- /// as to whether detailed errors are provided during an in-flight request.
- ///
- ///
- public bool IncludeDetailedErrors { get; set; }
-
- ///
- /// Gets or sets a value indicating whether to allow concurrent invocation of functions.
- ///
- ///
- /// if multiple function calls can execute in parallel.
- /// if function calls are processed serially.
- /// The default value is .
- ///
- ///
- /// An individual response from the inner client might contain multiple function call requests.
- /// By default, such function calls are processed serially. Set to
- /// to enable concurrent invocation such that multiple function calls can execute in parallel.
- ///
- public bool AllowConcurrentInvocation { get; set; }
-
- ///
- /// Gets or sets the maximum number of iterations per request.
- ///
- ///
- /// The maximum number of iterations per request.
- /// The default value is 10.
- ///
- ///
- ///
- /// Each request to this might end up making
- /// multiple requests to the inner client. Each time the inner client responds with
- /// a function call request, this client might perform that invocation and send the results
- /// back to the inner client in a new request. This property limits the number of times
- /// such a roundtrip is performed. The value must be at least one, as it includes the initial request.
- ///
- ///
- /// Changing the value of this property while the client is in use might result in inconsistencies
- /// as to how many iterations are allowed for an in-flight request.
- ///
- ///
- public int MaximumIterationsPerRequest
- {
- get => _maximumIterationsPerRequest;
- set
- {
- if (value < 1)
- {
- Throw.ArgumentOutOfRangeException(nameof(value));
- }
-
- _maximumIterationsPerRequest = value;
- }
- }
-
- ///
- /// Gets or sets the maximum number of consecutive iterations that are allowed to fail with an error.
- ///
- ///
- /// The maximum number of consecutive iterations that are allowed to fail with an error.
- /// The default value is 3.
- ///
- ///
- ///
- /// When function invocations fail with an exception, the
- /// continues to make requests to the inner client, optionally supplying exception information (as
- /// controlled by ). This allows the to
- /// recover from errors by trying other function parameters that may succeed.
- ///
- ///
- /// However, in case function invocations continue to produce exceptions, this property can be used to
- /// limit the number of consecutive failing attempts. When the limit is reached, the exception will be
- /// rethrown to the caller.
- ///
- ///
- /// If the value is set to zero, all function calling exceptions immediately terminate the function
- /// invocation loop and the exception will be rethrown to the caller.
- ///
- ///
- /// Changing the value of this property while the client is in use might result in inconsistencies
- /// as to how many iterations are allowed for an in-flight request.
- ///
- ///
- public int MaximumConsecutiveErrorsPerRequest
- {
- get => _maximumConsecutiveErrorsPerRequest;
- set => _maximumConsecutiveErrorsPerRequest = Throw.IfLessThan(value, 0);
- }
-
- ///
- public override async Task GetResponseAsync(
- IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default)
- {
- _ = Throw.IfNull(messages);
-
- // A single request into this GetResponseAsync may result in multiple requests to the inner client.
- // Create an activity to group them together for better observability.
- using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
-
- // Copy the original messages in order to avoid enumerating the original messages multiple times.
- // The IEnumerable can represent an arbitrary amount of work.
- List originalMessages = [.. messages];
- messages = originalMessages;
-
- List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
- ChatResponse? response = null; // the response from the inner client, which is possibly modified and then eventually returned
- List? responseMessages = null; // tracked list of messages, across multiple turns, to be used for the final response
- UsageDetails? totalUsage = null; // tracked usage across all turns, to be used for the final response
- List? functionCallContents = null; // function call contents that need responding to in the current turn
- bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set
- int consecutiveErrorCount = 0;
-
- for (int iteration = 0; ; iteration++)
- {
- functionCallContents?.Clear();
-
- // Make the call to the inner client.
- response = await base.GetResponseAsync(messages, options, cancellationToken);
- if (response is null)
- {
- Throw.InvalidOperationException($"The inner {nameof(IChatClient)} returned a null {nameof(ChatResponse)}.");
- }
-
- // Any function call work to do? If yes, ensure we're tracking that work in functionCallContents.
- bool requiresFunctionInvocation =
- options?.Tools is { Count: > 0 } &&
- iteration < MaximumIterationsPerRequest &&
- CopyFunctionCalls(response.Messages, ref functionCallContents);
-
- // In a common case where we make a request and there's no function calling work required,
- // fast path out by just returning the original response.
- if (iteration == 0 && !requiresFunctionInvocation)
- {
- return response;
- }
-
- // Track aggregatable details from the response, including all of the response messages and usage details.
- (responseMessages ??= []).AddRange(response.Messages);
- if (response.Usage is not null)
- {
- if (totalUsage is not null)
- {
- totalUsage.Add(response.Usage);
- }
- else
- {
- totalUsage = response.Usage;
- }
- }
-
- // If there are no tools to call, or for any other reason we should stop, we're done.
- // Break out of the loop and allow the handling at the end to configure the response
- // with aggregated data from previous requests.
- if (!requiresFunctionInvocation)
- {
- break;
- }
-
- // Prepare the history for the next iteration.
- FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
-
- // Add the responses from the function calls into the augmented history and also into the tracked
- // list of response messages.
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options!, functionCallContents!, iteration, consecutiveErrorCount, cancellationToken);
- responseMessages.AddRange(modeAndMessages.MessagesAdded);
- consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
-
- if (modeAndMessages.ShouldTerminate)
- {
- break;
- }
-
- UpdateOptionsForNextIteration(ref options!, response.ChatThreadId);
- }
-
- Debug.Assert(responseMessages is not null, "Expected to only be here if we have response messages.");
- response.Messages = responseMessages!;
- response.Usage = totalUsage;
-
- return response;
- }
-
- ///
- public override async IAsyncEnumerable GetStreamingResponseAsync(
- IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- _ = Throw.IfNull(messages);
-
- // A single request into this GetStreamingResponseAsync may result in multiple requests to the inner client.
- // Create an activity to group them together for better observability.
- using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
-
- // Copy the original messages in order to avoid enumerating the original messages multiple times.
- // The IEnumerable can represent an arbitrary amount of work.
- List originalMessages = [.. messages];
- messages = originalMessages;
-
- List? augmentedHistory = null; // the actual history of messages sent on turns other than the first
- List? functionCallContents = null; // function call contents that need responding to in the current turn
- List? responseMessages = null; // tracked list of messages, across multiple turns, to be used in fallback cases to reconstitute history
- bool lastIterationHadThreadId = false; // whether the last iteration's response had a ChatThreadId set
- List updates = []; // updates from the current response
- int consecutiveErrorCount = 0;
-
- for (int iteration = 0; ; iteration++)
- {
- updates.Clear();
- functionCallContents?.Clear();
-
- await foreach (var update in base.GetStreamingResponseAsync(messages, options, cancellationToken))
- {
- if (update is null)
- {
- Throw.InvalidOperationException($"The inner {nameof(IChatClient)} streamed a null {nameof(ChatResponseUpdate)}.");
- }
-
- updates.Add(update);
-
- _ = CopyFunctionCalls(update.Contents, ref functionCallContents);
-
- yield return update;
- Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
- }
-
- // If there are no tools to call, or for any other reason we should stop, return the response.
- if (functionCallContents is not { Count: > 0 } ||
- options?.Tools is not { Count: > 0 } ||
- iteration >= _maximumIterationsPerRequest)
- {
- break;
- }
-
- // Reconsistitue a response from the response updates.
- var response = updates.ToChatResponse();
- (responseMessages ??= []).AddRange(response.Messages);
-
- // Prepare the history for the next iteration.
- FixupHistories(originalMessages, ref messages, ref augmentedHistory, response, responseMessages, ref lastIterationHadThreadId);
-
- // Process all of the functions, adding their results into the history.
- var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, cancellationToken);
- responseMessages.AddRange(modeAndMessages.MessagesAdded);
- consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
-
- // This is a synthetic ID since we're generating the tool messages instead of getting them from
- // the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to
- // use the same message ID for all of them within a given iteration, as this is a single logical
- // message with multiple content items. We could also use different message IDs per tool content,
- // but there's no benefit to doing so.
- string toolResponseId = Guid.NewGuid().ToString("N");
-
- // Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
- // includes all activitys, including generated function results.
- foreach (var message in modeAndMessages.MessagesAdded)
- {
- var toolResultUpdate = new ChatResponseUpdate
- {
- AdditionalProperties = message.AdditionalProperties,
- AuthorName = message.AuthorName,
- ChatThreadId = response.ChatThreadId,
- CreatedAt = DateTimeOffset.UtcNow,
- Contents = message.Contents,
- RawRepresentation = message.RawRepresentation,
- ResponseId = toolResponseId,
- MessageId = toolResponseId, // See above for why this can be the same as ResponseId
- Role = message.Role,
- };
-
- yield return toolResultUpdate;
- Activity.Current = activity; // workaround for https://github.com/dotnet/runtime/issues/47802
- }
-
- if (modeAndMessages.ShouldTerminate)
- {
- yield break;
- }
-
- UpdateOptionsForNextIteration(ref options, response.ChatThreadId);
- }
- }
-
- /// Prepares the various chat message lists after a response from the inner client and before invoking functions.
- /// The original messages provided by the caller.
- /// The messages reference passed to the inner client.
- /// The augmented history containing all the messages to be sent.
- /// The most recent response being handled.
- /// A list of all response messages received up until this point.
- /// Whether the previous iteration's response had a thread id.
- private static void FixupHistories(
- IEnumerable originalMessages,
- ref IEnumerable messages,
- [NotNull] ref List? augmentedHistory,
- ChatResponse response,
- List allTurnsResponseMessages,
- ref bool lastIterationHadThreadId)
- {
- // We're now going to need to augment the history with function result contents.
- // That means we need a separate list to store the augmented history.
- if (response.ChatThreadId is not null)
- {
- // The response indicates the inner client is tracking the history, so we don't want to send
- // anything we've already sent or received.
- if (augmentedHistory is not null)
- {
- augmentedHistory.Clear();
- }
- else
- {
- augmentedHistory = [];
- }
-
- lastIterationHadThreadId = true;
- }
- else if (lastIterationHadThreadId)
- {
- // In the very rare case where the inner client returned a response with a thread ID but then
- // returned a subsequent response without one, we want to reconstitue the full history. To do that,
- // we can populate the history with the original chat messages and then all of the response
- // messages up until this point, which includes the most recent ones.
- augmentedHistory ??= [];
- augmentedHistory.Clear();
- augmentedHistory.AddRange(originalMessages);
- augmentedHistory.AddRange(allTurnsResponseMessages);
-
- lastIterationHadThreadId = false;
- }
- else
- {
- // If augmentedHistory is already non-null, then we've already populated it with everything up
- // until this point (except for the most recent response). If it's null, we need to seed it with
- // the chat history provided by the caller.
- augmentedHistory ??= originalMessages.ToList();
-
- // Now add the most recent response messages.
- augmentedHistory.AddMessages(response);
-
- lastIterationHadThreadId = false;
- }
-
- // Use the augmented history as the new set of messages to send.
- messages = augmentedHistory;
- }
-
- /// Copies any from to .
- private static bool CopyFunctionCalls(
- IList messages, [NotNullWhen(true)] ref List? functionCalls)
- {
- bool any = false;
- int count = messages.Count;
- for (int i = 0; i < count; i++)
- {
- any |= CopyFunctionCalls(messages[i].Contents, ref functionCalls);
- }
-
- return any;
- }
-
- /// Copies any from to .
- private static bool CopyFunctionCalls(
- IList content, [NotNullWhen(true)] ref List? functionCalls)
- {
- bool any = false;
- int count = content.Count;
- for (int i = 0; i < count; i++)
- {
- if (content[i] is FunctionCallContent functionCall)
- {
- (functionCalls ??= []).Add(functionCall);
- any = true;
- }
- }
-
- return any;
- }
-
- private static void UpdateOptionsForNextIteration(ref ChatOptions options, string? chatThreadId)
- {
- if (options.ToolMode is RequiredChatToolMode)
- {
- // We have to reset the tool mode to be non-required after the first iteration,
- // as otherwise we'll be in an infinite loop.
- options = options.Clone();
- options.ToolMode = null;
- options.ChatThreadId = chatThreadId;
- }
- else if (options.ChatThreadId != chatThreadId)
- {
- // As with the other modes, ensure we've propagated the chat thread ID to the options.
- // We only need to clone the options if we're actually mutating it.
- options = options.Clone();
- options.ChatThreadId = chatThreadId;
- }
- }
-
- ///
- /// Processes the function calls in the list.
- ///
- /// The current chat contents, inclusive of the function call contents being processed.
- /// The options used for the response being processed.
- /// The function call contents representing the functions to be invoked.
- /// The iteration number of how many roundtrips have been made to the inner client.
- /// The number of consecutive iterations, prior to this one, that were recorded as having function invocation errors.
- /// The to monitor for cancellation requests.
- /// A value indicating how the caller should proceed.
- private async Task<(bool ShouldTerminate, int NewConsecutiveErrorCount, IList MessagesAdded)> ProcessFunctionCallsAsync(
- List messages, ChatOptions options, List functionCallContents, int iteration, int consecutiveErrorCount, CancellationToken cancellationToken)
- {
- // We must add a response for every tool call, regardless of whether we successfully executed it or not.
- // If we successfully execute it, we'll add the result. If we don't, we'll add an error.
-
- Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call.");
-
- var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest;
-
- // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel.
- if (functionCallContents.Count == 1)
- {
- FunctionInvocationResult result = await ProcessFunctionCallAsync(
- messages, options, functionCallContents, iteration, 0, captureCurrentIterationExceptions, cancellationToken);
-
- IList added = CreateResponseMessages([result]);
- ThrowIfNoFunctionResultsAdded(added);
- UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount);
-
- messages.AddRange(added);
- return (result.ShouldTerminate, consecutiveErrorCount, added);
- }
- else
- {
- FunctionInvocationResult[] results;
-
- if (AllowConcurrentInvocation)
- {
- // Rather than await'ing each function before invoking the next, invoke all of them
- // and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
- // but if a function invocation completes asynchronously, its processing can overlap
- // with the processing of other the other invocation invocations.
- results = await Task.WhenAll(
- from i in Enumerable.Range(0, functionCallContents.Count)
- select ProcessFunctionCallAsync(
- messages, options, functionCallContents,
- iteration, i, captureExceptions: true, cancellationToken));
- }
- else
- {
- // Invoke each function serially.
- results = new FunctionInvocationResult[functionCallContents.Count];
- for (int i = 0; i < results.Length; i++)
- {
- results[i] = await ProcessFunctionCallAsync(
- messages, options, functionCallContents,
- iteration, i, captureCurrentIterationExceptions, cancellationToken);
- }
- }
-
- var shouldTerminate = false;
-
- IList added = CreateResponseMessages(results);
- ThrowIfNoFunctionResultsAdded(added);
- UpdateConsecutiveErrorCountOrThrow(added, ref consecutiveErrorCount);
-
- messages.AddRange(added);
- foreach (FunctionInvocationResult fir in results)
- {
- shouldTerminate = shouldTerminate || fir.ShouldTerminate;
- }
-
- return (shouldTerminate, consecutiveErrorCount, added);
- }
- }
-
- private void UpdateConsecutiveErrorCountOrThrow(IList added, ref int consecutiveErrorCount)
- {
- var allExceptions = added.SelectMany(m => m.Contents.OfType())
- .Select(frc => frc.Exception!)
- .Where(e => e is not null);
-
- if (allExceptions.Any())
- {
- consecutiveErrorCount++;
- if (consecutiveErrorCount > _maximumConsecutiveErrorsPerRequest)
- {
- var allExceptionsArray = allExceptions.ToArray();
- if (allExceptionsArray.Length == 1)
- {
- ExceptionDispatchInfo.Capture(allExceptionsArray[0]).Throw();
- }
- else
- {
- throw new AggregateException(allExceptionsArray);
- }
- }
- }
- else
- {
- consecutiveErrorCount = 0;
- }
- }
-
- ///
- /// Throws an exception if doesn't create any messages.
- ///
- private void ThrowIfNoFunctionResultsAdded(IList? messages)
- {
- if (messages is null || messages.Count == 0)
- {
- Throw.InvalidOperationException($"{GetType().Name}.{nameof(CreateResponseMessages)} returned null or an empty collection of messages.");
- }
- }
-
- /// Processes the function call described in [].
- /// The current chat contents, inclusive of the function call contents being processed.
- /// The options used for the response being processed.
- /// The function call contents representing all the functions being invoked.
- /// The iteration number of how many roundtrips have been made to the inner client.
- /// The 0-based index of the function being called out of .
- /// If true, handles function-invocation exceptions by returning a value with . Otherwise, rethrows.
- /// The to monitor for cancellation requests.
- /// A value indicating how the caller should proceed.
- private async Task ProcessFunctionCallAsync(
- List messages, ChatOptions options, List callContents,
- int iteration, int functionCallIndex, bool captureExceptions, CancellationToken cancellationToken)
- {
- var callContent = callContents[functionCallIndex];
-
- // Look up the AIFunction for the function call. If the requested function isn't available, send back an error.
- AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Name == callContent.Name);
- if (function is null)
- {
- return new(shouldTerminate: false, FunctionInvocationStatus.NotFound, callContent, result: null, exception: null);
- }
-
- FunctionInvocationContext context = new()
- {
- Function = function,
- Arguments = new(callContent.Arguments) { Services = _functionInvocationServices },
-
- Messages = messages,
- Options = options,
-
- CallContent = callContent,
- Iteration = iteration,
- FunctionCallIndex = functionCallIndex,
- FunctionCount = callContents.Count,
- };
-
- object? result;
- try
- {
- result = await InvokeFunctionAsync(context, cancellationToken);
- }
- catch (Exception e) when (!cancellationToken.IsCancellationRequested)
- {
- if (!captureExceptions)
- {
- throw;
- }
-
- return new(
- shouldTerminate: false,
- FunctionInvocationStatus.Exception,
- callContent,
- result: null,
- exception: e);
- }
-
- return new(
- shouldTerminate: context.Terminate,
- FunctionInvocationStatus.RanToCompletion,
- callContent,
- result,
- exception: null);
- }
-
- /// Creates one or more response messages for function invocation results.
- /// Information about the function call invocations and results.
- /// A list of all chat messages created from .
- protected virtual IList CreateResponseMessages(
- ReadOnlySpan results)
- {
- var contents = new List(results.Length);
- for (int i = 0; i < results.Length; i++)
- {
- contents.Add(CreateFunctionResultContent(results[i]));
- }
-
- return [new(ChatRole.Tool, contents)];
-
- FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result)
- {
- _ = Throw.IfNull(result);
-
- object? functionResult;
- if (result.Status == FunctionInvocationStatus.RanToCompletion)
- {
- functionResult = result.Result ?? "Success: Function completed.";
- }
- else
- {
- string message = result.Status switch
- {
- FunctionInvocationStatus.NotFound => $"Error: Requested function \"{result.CallContent.Name}\" not found.",
- FunctionInvocationStatus.Exception => "Error: Function failed.",
- _ => "Error: Unknown error.",
- };
-
- if (IncludeDetailedErrors && result.Exception is not null)
- {
- message = $"{message} Exception: {result.Exception.Message}";
- }
-
- functionResult = message;
- }
-
- return new FunctionResultContent(result.CallContent.CallId, functionResult) { Exception = result.Exception };
- }
- }
-
- /// Invokes the function asynchronously.
- ///
- /// The function invocation context detailing the function to be invoked and its arguments along with additional request information.
- ///
- /// The to monitor for cancellation requests. The default is .
- /// The result of the function invocation, or if the function invocation returned .
- /// is .
- protected virtual async Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken)
- {
- _ = Throw.IfNull(context);
-
- using Activity? activity = _activitySource?.StartActivity(context.Function.Name);
-
- long startingTimestamp = 0;
- if (_logger.IsEnabled(LogLevel.Debug))
- {
- startingTimestamp = Stopwatch.GetTimestamp();
- if (_logger.IsEnabled(LogLevel.Trace))
- {
- LogInvokingSensitive(context.Function.Name, LoggingHelpers.AsJson(context.Arguments, context.Function.JsonSerializerOptions));
- }
- else
- {
- LogInvoking(context.Function.Name);
- }
- }
-
- object? result = null;
- try
- {
- CurrentContext = context; // doesn't need to be explicitly reset after, as that's handled automatically at async method exit
- result = await context.Function.InvokeAsync(context.Arguments, cancellationToken);
- }
- catch (Exception e)
- {
- if (activity is not null)
- {
- _ = activity.SetTag("error.type", e.GetType().FullName)
- .SetStatus(ActivityStatusCode.Error, e.Message);
- }
-
- if (e is OperationCanceledException)
- {
- LogInvocationCanceled(context.Function.Name);
- }
- else
- {
- LogInvocationFailed(context.Function.Name, e);
- }
-
- throw;
- }
- finally
- {
- if (_logger.IsEnabled(LogLevel.Debug))
- {
- TimeSpan elapsed = GetElapsedTime(startingTimestamp);
-
- if (result is not null && _logger.IsEnabled(LogLevel.Trace))
- {
- LogInvocationCompletedSensitive(context.Function.Name, elapsed, LoggingHelpers.AsJson(result, context.Function.JsonSerializerOptions));
- }
- else
- {
- LogInvocationCompleted(context.Function.Name, elapsed);
- }
- }
- }
-
- return result;
- }
-
- private static TimeSpan GetElapsedTime(long startingTimestamp) =>
-#if NET
- Stopwatch.GetElapsedTime(startingTimestamp);
-#else
- new((long)((Stopwatch.GetTimestamp() - startingTimestamp) * ((double)TimeSpan.TicksPerSecond / Stopwatch.Frequency)));
-#endif
-
- [LoggerMessage(LogLevel.Debug, "Invoking {MethodName}.", SkipEnabledCheck = true)]
- private partial void LogInvoking(string methodName);
-
- [LoggerMessage(LogLevel.Trace, "Invoking {MethodName}({Arguments}).", SkipEnabledCheck = true)]
- private partial void LogInvokingSensitive(string methodName, string arguments);
-
- [LoggerMessage(LogLevel.Debug, "{MethodName} invocation completed. Duration: {Duration}", SkipEnabledCheck = true)]
- private partial void LogInvocationCompleted(string methodName, TimeSpan duration);
-
- [LoggerMessage(LogLevel.Trace, "{MethodName} invocation completed. Duration: {Duration}. Result: {Result}", SkipEnabledCheck = true)]
- private partial void LogInvocationCompletedSensitive(string methodName, TimeSpan duration, string result);
-
- [LoggerMessage(LogLevel.Debug, "{MethodName} invocation canceled.")]
- private partial void LogInvocationCanceled(string methodName);
-
- [LoggerMessage(LogLevel.Error, "{MethodName} invocation failed.")]
- private partial void LogInvocationFailed(string methodName, Exception error);
-
- /// Provides information about the invocation of a function call.
- public sealed class FunctionInvocationResult
- {
- internal FunctionInvocationResult(bool shouldTerminate, FunctionInvocationStatus status, FunctionCallContent callContent, object? result, Exception? exception)
- {
- ShouldTerminate = shouldTerminate;
- Status = status;
- CallContent = callContent;
- Result = result;
- Exception = exception;
- }
-
- /// Gets status about how the function invocation completed.
- public FunctionInvocationStatus Status { get; }
-
- /// Gets the function call content information associated with this invocation.
- public FunctionCallContent CallContent { get; }
-
- /// Gets the result of the function call.
- public object? Result { get; }
-
- /// Gets any exception the function call threw.
- public Exception? Exception { get; }
-
- /// Gets a value indicating whether the caller should terminate the processing loop.
- internal bool ShouldTerminate { get; }
- }
-
- /// Provides error codes for when errors occur as part of the function calling loop.
- public enum FunctionInvocationStatus
- {
- /// The operation completed successfully.
- RanToCompletion,
-
- /// The requested function could not be found.
- NotFound,
-
- /// The function call failed with an exception.
- Exception,
- }
-}
From 6d7a374b9808d5e0f072a9279bb3ec1ec5db9833 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 14 Apr 2025 12:35:24 +0100
Subject: [PATCH 21/26] Removing KernelFunctionINvocationContext in favor of
Microsoft.Extensions.AI.FunctionInvocationContext
---
.../AI/ChatClient/ChatClientExtensions.cs | 2 +-
.../KernelFunctionInvocationContext.cs | 127 ------------------
.../AutoFunctionInvocationContext.cs | 4 +-
3 files changed, 3 insertions(+), 130 deletions(-)
delete mode 100644 dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
index 223530beb1ba..e035a436a83a 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientExtensions.cs
@@ -92,7 +92,7 @@ public static IChatCompletionService AsChatCompletionService(this IChatClient cl
/// Creates a new that supports for function invocation with a .
///
/// Target chat client service.
- /// Optional logger to use for logging.
+ /// Optional logger factory to use for logging.
/// Function invoking chat client.
[Experimental("SKEXP0001")]
public static IChatClient AsKernelFunctionInvokingChatClient(this IChatClient client, ILoggerFactory? loggerFactory = null)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
deleted file mode 100644
index efc3d39ab48a..000000000000
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvocationContext.cs
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
-using System.Runtime.CompilerServices;
-using Microsoft.Extensions.AI;
-
-#pragma warning disable IDE0009 // Use explicit 'this.' qualifier
-#pragma warning disable CA2213 // Disposable fields should be disposed
-#pragma warning disable IDE0044 // Add readonly modifier
-
-namespace Microsoft.SemanticKernel.ChatCompletion;
-
-// Slight modified source from
-// https://raw.githubusercontent.com/dotnet/extensions/refs/heads/main/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvocationContext.cs
-
-/// Provides context for an in-flight function invocation.
-[ExcludeFromCodeCoverage]
-internal class KernelFunctionInvocationContext
-{
- ///
- /// A nop function used to allow to be non-nullable. Default instances of
- /// start with this as the target function.
- ///
- private static readonly AIFunction _nopFunction = AIFunctionFactory.Create(() => { }, nameof(FunctionInvocationContext));
-
- /// The chat contents associated with the operation that initiated this function call request.
- private IList _messages = Array.Empty();
-
- /// The AI function to be invoked.
- private AIFunction _function = _nopFunction;
-
- /// The function call content information associated with this invocation.
- private Microsoft.Extensions.AI.FunctionCallContent? _callContent;
-
- /// The arguments used with the function.
- private AIFunctionArguments? _arguments;
-
- /// Initializes a new instance of the class.
- public KernelFunctionInvocationContext()
- {
- }
-
- /// Gets or sets the AI function to be invoked.
- public AIFunction Function
- {
- get => _function;
- set => _function = Throw.IfNull(value);
- }
-
- /// Gets or sets the arguments associated with this invocation.
- public AIFunctionArguments Arguments
- {
- get => _arguments ??= [];
- set => _arguments = Throw.IfNull(value);
- }
-
- /// Gets or sets the function call content information associated with this invocation.
- public Microsoft.Extensions.AI.FunctionCallContent CallContent
- {
- get => _callContent ??= new(string.Empty, _nopFunction.Name, EmptyReadOnlyDictionary.Instance);
- set => _callContent = Throw.IfNull(value);
- }
-
- /// Gets or sets the chat contents associated with the operation that initiated this function call request.
- public IList Messages
- {
- get => _messages;
- set => _messages = Throw.IfNull(value);
- }
-
- /// Gets or sets the chat options associated with the operation that initiated this function call request.
- public ChatOptions? Options { get; set; }
-
- /// Gets or sets the number of this iteration with the underlying client.
- ///
- /// The initial request to the client that passes along the chat contents provided to the
- /// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on.
- ///
- public int Iteration { get; set; }
-
- /// Gets or sets the index of the function call within the iteration.
- ///
- /// The response from the underlying client may include multiple function call requests.
- /// This index indicates the position of the function call within the iteration.
- ///
- public int FunctionCallIndex { get; set; }
-
- /// Gets or sets the total number of function call requests within the iteration.
- ///
- /// The response from the underlying client might include multiple function call requests.
- /// This count indicates how many there were.
- ///
- public int FunctionCount { get; set; }
-
- /// Gets or sets a value indicating whether to terminate the request.
- ///
- /// In response to a function call request, the function might be invoked, its result added to the chat contents,
- /// and a new request issued to the wrapped client. If this property is set to , that subsequent request
- /// will not be issued and instead the loop immediately terminated rather than continuing until there are no
- /// more function call requests in responses.
- ///
- public bool Terminate { get; set; }
-
- private static class Throw
- {
- ///
- /// Throws an if the specified argument is .
- ///
- /// Argument type to be checked for .
- /// Object to be checked for .
- /// The name of the parameter being checked.
- /// The original value of .
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- [return: NotNull]
- public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "")
- {
- if (argument is null)
- {
- throw new ArgumentNullException(paramName);
- }
-
- return argument;
- }
- }
-}
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index f70205365330..a0ddb480c275 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -17,12 +17,12 @@ public class AutoFunctionInvocationContext
{
private ChatHistory? _chatHistory;
private KernelFunction? _kernelFunction;
- private readonly KernelFunctionInvocationContext _invocationContext = new();
+ private readonly Microsoft.Extensions.AI.FunctionInvocationContext _invocationContext = new();
///
/// Initializes a new instance of the class from an existing .
///
- internal AutoFunctionInvocationContext(KernelFunctionInvocationContext invocationContext)
+ internal AutoFunctionInvocationContext(Microsoft.Extensions.AI.FunctionInvocationContext invocationContext)
{
Verify.NotNull(invocationContext);
Verify.NotNull(invocationContext.Options);
From 7a380273bd21e7dcd5c7e2db5908bdbd2c2f3bf8 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 14 Apr 2025 12:58:32 +0100
Subject: [PATCH 22/26] Fix reference
---
.../AutoFunctionInvocation/AutoFunctionInvocationContext.cs | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
index a0ddb480c275..bc8dd0c3490c 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs
@@ -20,7 +20,7 @@ public class AutoFunctionInvocationContext
private readonly Microsoft.Extensions.AI.FunctionInvocationContext _invocationContext = new();
///
- /// Initializes a new instance of the class from an existing .
+ /// Initializes a new instance of the class from an existing .
///
internal AutoFunctionInvocationContext(Microsoft.Extensions.AI.FunctionInvocationContext invocationContext)
{
From c989c142e7c4ad1daeb0dd4b88bdf8254c2ab396 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 14 Apr 2025 13:00:41 +0100
Subject: [PATCH 23/26] Typo fix
---
.../AI/ChatClient/KernelFunctionInvokingChatClient.cs | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index e6496cc63f4f..5c77bb415867 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -367,7 +367,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
break;
}
- // Reconsistitue a response from the response updates.
+ // Reconstitute a response from the response updates.
var response = updates.ToChatResponse();
(responseMessages ??= []).AddRange(response.Messages);
From 4316651ad72f81e418a4c0fbab1a750997218aa6 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Mon, 14 Apr 2025 13:08:06 +0100
Subject: [PATCH 24/26] Fix typos + virtual to private
---
.../KernelFunctionInvokingChatClient.cs | 23 +++++++++----------
1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 5c77bb415867..544fefe72493 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -117,13 +117,13 @@ public static AutoFunctionInvocationContext? CurrentContext
///
///
/// Setting the value to can help the underlying bypass problems on
- /// its own, for example by retrying the function call with different arguments. However it might
+ /// its own, for example by retrying the function call with different arguments. However, it might
/// result in disclosing the raw exception information to external users, which can be a security
/// concern depending on the application scenario.
///
///
/// Changing the value of this property while the client is in use might result in inconsistencies
- /// as to whether detailed errors are provided during an in-flight request.
+ /// whether detailed errors are provided during an in-flight request.
///
///
public bool IncludeDetailedErrors { get; set; }
@@ -265,7 +265,7 @@ public override async Task GetResponseAsync(
return response;
}
- // Track aggregatable details from the response, including all of the response messages and usage details.
+ // Track aggregate details from the response, including all the response messages and usage details.
(responseMessages ??= []).AddRange(response.Messages);
if (response.Usage is not null)
{
@@ -377,7 +377,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
// Prepare the options for the next auto function invocation iteration.
UpdateOptionsForAutoFunctionInvocation(ref options, response.Messages.Last().ToChatMessageContent(), isStreaming: true);
- // Process all of the functions, adding their results into the history.
+ // Process all the functions, adding their results into the history.
var modeAndMessages = await ProcessFunctionCallsAsync(augmentedHistory, options, functionCallContents, iteration, consecutiveErrorCount, isStreaming: true, cancellationToken).ConfigureAwait(false);
responseMessages.AddRange(modeAndMessages.MessagesAdded);
consecutiveErrorCount = modeAndMessages.NewConsecutiveErrorCount;
@@ -393,7 +393,7 @@ public override async IAsyncEnumerable GetStreamingResponseA
string toolResponseId = Guid.NewGuid().ToString("N");
// Stream any generated function results. This mirrors what's done for GetResponseAsync, where the returned messages
- // includes all activitys, including generated function results.
+ // include all activity, including generated function results.
foreach (var message in modeAndMessages.MessagesAdded)
{
var toolResultUpdate = new ChatResponseUpdate
@@ -457,8 +457,8 @@ private static void FixupHistories(
else if (lastIterationHadThreadId)
{
// In the very rare case where the inner client returned a response with a thread ID but then
- // returned a subsequent response without one, we want to reconstitue the full history. To do that,
- // we can populate the history with the original chat messages and then all of the response
+ // returned a subsequent response without one, we want to reconstitute the full history. To do that,
+ // we can populate the history with the original chat messages and then all the response
// messages up until this point, which includes the most recent ones.
augmentedHistory ??= [];
augmentedHistory.Clear();
@@ -584,7 +584,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
// We must add a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
- Debug.Assert(functionCallContents.Count > 0, "Expecteded at least one function call.");
+ Debug.Assert(functionCallContents.Count > 0, "Expected at least one function call.");
var shouldTerminate = false;
var captureCurrentIterationExceptions = consecutiveErrorCount < _maximumConsecutiveErrorsPerRequest;
@@ -609,7 +609,7 @@ private static void UpdateOptionsForNextIteration(ref ChatOptions options, strin
var terminationRequested = false;
if (AllowConcurrentInvocation)
{
- // Rather than await'ing each function before invoking the next, invoke all of them
+ // Rather than awaiting each function before invoking the next, invoke all of them
// and then await all of them. We avoid forcibly introducing parallelism via Task.Run,
// but if a function invocation completes asynchronously, its processing can overlap
// with the processing of the other invocations.
@@ -776,8 +776,7 @@ private async Task ProcessFunctionCallAsync(
/// Creates one or more response messages for function invocation results.
/// Information about the function call invocations and results.
/// A list of all chat messages created from .
- protected virtual IList CreateResponseMessages(
- IReadOnlyList results)
+ private IList CreateResponseMessages(List results)
{
var contents = new List(results.Count);
for (int i = 0; i < results.Count; i++)
@@ -866,7 +865,7 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(
/// The to monitor for cancellation requests. The default is .
/// The result of the function invocation, or if the function invocation returned .
/// is .
- protected virtual async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken)
+ private async Task InvokeFunctionAsync(AutoFunctionInvocationContext context, CancellationToken cancellationToken)
{
Verify.NotNull(context);
From cb636e941155aec247f9451056d4912a5880d78f Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Tue, 15 Apr 2025 13:32:03 +0100
Subject: [PATCH 25/26] Address PR comments
---
.../OpenAIKernelBuilderExtensions.ChatClient.cs | 11 +++++------
.../Extensions/OpenAIKernelBuilderExtensions.cs | 8 ++++----
.../AI/ChatClient/ChatOptionsExtensions.cs | 2 +-
.../AI/ChatClient/KernelFunctionInvokingChatClient.cs | 2 +-
.../Functions/KernelFunction.cs | 1 -
5 files changed, 11 insertions(+), 13 deletions(-)
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
index ac0005f5d5c5..9d1832b340ff 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.ChatClient.cs
@@ -3,20 +3,19 @@
using System;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
+using Microsoft.Extensions.AI;
using OpenAI;
namespace Microsoft.SemanticKernel;
-///
-/// Sponsor extensions class for .
-///
+/// Extension methods for .
[Experimental("SKEXP0010")]
public static class OpenAIChatClientKernelBuilderExtensions
{
#region Chat Completion
///
- /// Adds the OpenAI chat completion service to the list.
+ /// Adds an OpenAI to the .
///
/// The instance to augment.
/// OpenAI model name, see https://platform.openai.com/docs/models
@@ -46,7 +45,7 @@ public static IKernelBuilder AddOpenAIChatClient(
}
///
- /// Adds the OpenAI chat completion service to the list.
+ /// Adds an OpenAI to the .
///
/// The instance to augment.
/// OpenAI model id
@@ -70,7 +69,7 @@ public static IKernelBuilder AddOpenAIChatClient(
}
///
- /// Adds the Custom Endpoint OpenAI chat completion service to the list.
+ /// Adds a custom endpoint OpenAI to the .
///
/// The instance to augment.
/// OpenAI model name, see https://platform.openai.com/docs/models
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs
index a07a81fdb5b3..01307b9adc2a 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/Extensions/OpenAIKernelBuilderExtensions.cs
@@ -20,7 +20,7 @@
namespace Microsoft.SemanticKernel;
///
-/// Sponsor extensions class for .
+/// Extension methods for .
///
public static class OpenAIKernelBuilderExtensions
{
@@ -269,7 +269,7 @@ public static IKernelBuilder AddOpenAIFiles(
#region Chat Completion
///
- /// Adds the OpenAI chat completion service to the list.
+ /// Adds an to the .
///
/// The instance to augment.
/// OpenAI model name, see https://platform.openai.com/docs/models
@@ -304,7 +304,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _)
}
///
- /// Adds the OpenAI chat completion service to the list.
+ /// Adds an to the .
///
/// The instance to augment.
/// OpenAI model id
@@ -330,7 +330,7 @@ OpenAIChatCompletionService Factory(IServiceProvider serviceProvider, object? _)
}
///
- /// Adds the Custom Endpoint OpenAI chat completion service to the list.
+ /// Adds a custom endpoint to the .
///
/// The instance to augment.
/// OpenAI model name, see https://platform.openai.com/docs/models
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
index 5db8240b1707..68540a1c32d8 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatOptionsExtensions.cs
@@ -137,7 +137,7 @@ internal static ChatOptions AddKernel(this ChatOptions options, Kernel? kernel)
if (kernel is not null)
{
options.AdditionalProperties ??= [];
- options.AdditionalProperties?.TryAdd(KernelKey, kernel);
+ options.AdditionalProperties.TryAdd(KernelKey, kernel);
}
return options;
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
index 544fefe72493..ea2dce48fc62 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/KernelFunctionInvokingChatClient.cs
@@ -226,7 +226,7 @@ public override async Task GetResponseAsync(
// A single request into this GetResponseAsync may result in multiple requests to the inner client.
// Create an activity to group them together for better observability.
- using Activity? activity = _activitySource?.StartActivity(nameof(FunctionInvokingChatClient));
+ using Activity? activity = _activitySource?.StartActivity(nameof(KernelFunctionInvokingChatClient));
// Copy the original messages in order to avoid enumerating the original messages multiple times.
// The IEnumerable can represent an arbitrary amount of work.
diff --git a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
index cb480c71c605..0e09c1d525b7 100644
--- a/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs
@@ -542,7 +542,6 @@ public KernelAIFunction(KernelFunction kernelFunction, Kernel? kernel)
this.JsonSchema = BuildFunctionSchema(kernelFunction);
}
- internal string KernelFunctionName => this._kernelFunction.Name;
public override string Name { get; }
public override JsonElement JsonSchema { get; }
public override string Description => this._kernelFunction.Description;
From 7043c726c2de9d455a28fed55427f4de639f7384 Mon Sep 17 00:00:00 2001
From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com>
Date: Wed, 16 Apr 2025 14:29:14 +0100
Subject: [PATCH 26/26] Address PR feedback
---
.../AI/ChatClient/ChatClientAIService.cs | 2 +-
.../AI/ChatCompletion/ChatHistory.cs | 4 +++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs
index b840b33e690b..58ea317804f9 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatClient/ChatClientAIService.cs
@@ -33,7 +33,7 @@ internal ChatClientAIService(IChatClient chatClient)
var metadata = this._chatClient.GetService();
Verify.NotNull(metadata);
- this._internalAttributes["ModelId"] = metadata.DefaultModelId;
+ this._internalAttributes[AIServiceExtensions.ModelIdKey] = metadata.DefaultModelId;
this._internalAttributes[nameof(metadata.ProviderName)] = metadata.ProviderName;
this._internalAttributes[nameof(metadata.ProviderUri)] = metadata.ProviderUri;
}
diff --git a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
index 0054dc98a400..147cdd5ba332 100644
--- a/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
+++ b/dotnet/src/SemanticKernel.Abstractions/AI/ChatCompletion/ChatHistory.cs
@@ -35,7 +35,9 @@ public ChatHistory()
this._messages = [];
}
- // Due to changes using the AutoFunctionInvocation as a dependency of KernelInvocation, that needs to reflect
+ // This allows changes to the chat history to be observed by reference, so they are
+ // reflected in an internal IEnumerable when used from IChatClients
+ // with AutoFunctionInvocationFilters
internal void SetOverrides(
Action overrideAdd,
Func overrideRemove,