diff --git a/dotnet/samples/Concepts/FunctionCalling/OpenAI_FunctionCalling.cs b/dotnet/samples/Concepts/FunctionCalling/OpenAI_FunctionCalling.cs
index 1b817fbc60fe..3acf2bc61b95 100644
--- a/dotnet/samples/Concepts/FunctionCalling/OpenAI_FunctionCalling.cs
+++ b/dotnet/samples/Concepts/FunctionCalling/OpenAI_FunctionCalling.cs
@@ -233,7 +233,7 @@ public async Task RunNonStreamingPromptWithSimulatedFunctionAsync()
}
// Adding a simulated function call to the connector response message
- FunctionCallContent simulatedFunctionCall = new("weather-alert", id: "call_123");
+ FunctionCallContent simulatedFunctionCall = new("weather_alert", id: "call_123");
result.Items.Add(simulatedFunctionCall);
// Adding a simulated function result to chat history
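
For context: the hunk above only renames the simulated function, but the full pattern the sample exercises looks roughly like the sketch below (a minimal reconstruction assuming Semantic Kernel's FunctionCallContent/FunctionResultContent API; result and chatHistory come from the surrounding sample, and the alert text is illustrative).

    // Attach the simulated call to the model's response message and record it in history.
    FunctionCallContent simulatedFunctionCall = new("weather_alert", id: "call_123");
    result.Items.Add(simulatedFunctionCall);
    chatHistory.Add(result);

    // Send the simulated result back to the model as a tool message.
    var simulatedFunctionResult = new FunctionResultContent(simulatedFunctionCall, "A Tornado Watch has been issued for the area.");
    chatHistory.Add(new ChatMessageContent(AuthorRole.Tool, [simulatedFunctionResult]));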
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
index 8059077d8bf4..18594daa58db 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs
@@ -51,6 +51,26 @@ internal abstract class ClientCore
/// </remarks>
private const int MaxInflightAutoInvokes = 128;
+ /// <summary>
+ /// The maximum number of function auto-invokes that can be made in a single user request.
+ /// </summary>
+ /// <remarks>
+ /// Once this number of iterations is reached within a single user request, auto-invocation
+ /// will be disabled. This is a safeguard against possible runaway execution if the model routinely re-requests
+ /// the same function over and over.
+ /// </remarks>
+ private const int MaximumAutoInvokeAttempts = 128;
+
+ /// <summary>
+ /// The number of requests that are part of a single user interaction that should include these functions in the request.
+ /// </summary>
+ /// <remarks>
+ /// Once this limit is reached, the functions will no longer be included in subsequent requests that are part of the user operation, e.g.
+ /// if this is 1, the first request will include the functions, but the subsequent response sending back the functions' result
+ /// will not include the functions for further use.
+ /// </remarks>
+ private const int MaximumUseAttempts = 1;
+
/// <summary>Singleton tool used when tool call count drops to 0 but we need to supply tools to keep the service happy.</summary>
private static readonly ChatCompletionsFunctionToolDefinition s_nonInvocableFunctionTool = new() { Name = "NonInvocableTool" };
@@ -384,13 +404,16 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// Convert the incoming execution settings to OpenAI settings.
OpenAIPromptExecutionSettings chatExecutionSettings = OpenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
- bool autoInvoke = kernel is not null && chatExecutionSettings.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;
ValidateMaxTokens(chatExecutionSettings.MaxTokens);
- ValidateAutoInvoke(autoInvoke, chatExecutionSettings.ResultsPerPrompt);
// Create the Azure SDK ChatCompletionOptions instance from all available information.
var chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chat, kernel, this.DeploymentOrModelName);
+ var functionCallConfiguration = this.ConfigureFunctionCalling(requestIndex: 0, kernel, chatExecutionSettings, chatOptions);
+
+ bool autoInvoke = kernel is not null && functionCallConfiguration?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;
+ ValidateAutoInvoke(autoInvoke, chatExecutionSettings.ResultsPerPrompt);
+
for (int requestIndex = 1; ; requestIndex++)
{
// Make the request.
@@ -490,7 +513,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
// then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able
// to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check.
- if (chatExecutionSettings.ToolCallBehavior?.AllowAnyRequestedKernelFunction is not true &&
+ if (functionCallConfiguration?.AllowAnyRequestedKernelFunction is not true &&
!IsRequestableTool(chatOptions, openAIFunctionToolCall))
{
AddResponseMessage(chatOptions, chat, result: null, "Error: Function call request for a function that wasn't defined.", toolCall, this.Logger);
@@ -564,42 +587,15 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
}
// Update tool use information for the next go-around based on having completed another iteration.
- Debug.Assert(chatExecutionSettings.ToolCallBehavior is not null);
-
- // Set the tool choice to none. If we end up wanting to use tools, we'll reset it to the desired value.
- chatOptions.ToolChoice = ChatCompletionsToolChoice.None;
- chatOptions.Tools.Clear();
-
- if (requestIndex >= chatExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
- {
- // Don't add any tools as we've reached the maximum attempts limit.
- if (this.Logger.IsEnabled(LogLevel.Debug))
- {
- this.Logger.LogDebug("Maximum use ({MaximumUse}) reached; removing the tool.", chatExecutionSettings.ToolCallBehavior!.MaximumUseAttempts);
- }
- }
- else
- {
- // Regenerate the tool list as necessary. The invocation of the function(s) could have augmented
- // what functions are available in the kernel.
- chatExecutionSettings.ToolCallBehavior.ConfigureOptions(kernel, chatOptions);
- }
-
- // Having already sent tools and with tool call information in history, the service can become unhappy ("[] is too short - 'tools'")
- // if we don't send any tools in subsequent requests, even if we say not to use any.
- if (chatOptions.ToolChoice == ChatCompletionsToolChoice.None)
- {
- Debug.Assert(chatOptions.Tools.Count == 0);
- chatOptions.Tools.Add(s_nonInvocableFunctionTool);
- }
+ functionCallConfiguration = this.ConfigureFunctionCalling(requestIndex, kernel, chatExecutionSettings, chatOptions);
// Disable auto invocation if we've exceeded the allowed limit.
- if (requestIndex >= chatExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
+ if (requestIndex >= functionCallConfiguration?.MaximumAutoInvokeAttempts)
{
autoInvoke = false;
if (this.Logger.IsEnabled(LogLevel.Debug))
{
- this.Logger.LogDebug("Maximum auto-invoke ({MaximumAutoInvoke}) reached.", chatExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts);
+ this.Logger.LogDebug("Maximum auto-invoke ({MaximumAutoInvoke}) reached.", functionCallConfiguration?.MaximumAutoInvokeAttempts);
}
}
}
@@ -614,14 +610,15 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
Verify.NotNull(chat);
OpenAIPromptExecutionSettings chatExecutionSettings = OpenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
-
ValidateMaxTokens(chatExecutionSettings.MaxTokens);
- bool autoInvoke = kernel is not null && chatExecutionSettings.ToolCallBehavior?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;
- ValidateAutoInvoke(autoInvoke, chatExecutionSettings.ResultsPerPrompt);
-
var chatOptions = this.CreateChatCompletionsOptions(chatExecutionSettings, chat, kernel, this.DeploymentOrModelName);
+ var functionCallConfiguration = this.ConfigureFunctionCalling(requestIndex: 0, kernel, chatExecutionSettings, chatOptions);
+
+ bool autoInvoke = kernel is not null && functionCallConfiguration?.MaximumAutoInvokeAttempts > 0 && s_inflightAutoInvokes.Value < MaxInflightAutoInvokes;
+ ValidateAutoInvoke(autoInvoke, chatExecutionSettings.ResultsPerPrompt);
+
StringBuilder? contentBuilder = null;
Dictionary<int, string>? toolCallIdsByIndex = null;
Dictionary<int, string>? functionNamesByIndex = null;
@@ -777,7 +774,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// Make sure the requested function is one we requested. If we're permitting any kernel function to be invoked,
// then we don't need to check this, as it'll be handled when we look up the function in the kernel to be able
// to invoke it. If we're permitting only a specific list of functions, though, then we need to explicitly check.
- if (chatExecutionSettings.ToolCallBehavior?.AllowAnyRequestedKernelFunction is not true &&
+ if (functionCallConfiguration?.AllowAnyRequestedKernelFunction is not true &&
!IsRequestableTool(chatOptions, openAIFunctionToolCall))
{
AddResponseMessage(chatOptions, chat, result: null, "Error: Function call request for a function that wasn't defined.", toolCall, this.Logger);
@@ -854,42 +851,15 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
}
// Update tool use information for the next go-around based on having completed another iteration.
- Debug.Assert(chatExecutionSettings.ToolCallBehavior is not null);
-
- // Set the tool choice to none. If we end up wanting to use tools, we'll reset it to the desired value.
- chatOptions.ToolChoice = ChatCompletionsToolChoice.None;
- chatOptions.Tools.Clear();
-
- if (requestIndex >= chatExecutionSettings.ToolCallBehavior!.MaximumUseAttempts)
- {
- // Don't add any tools as we've reached the maximum attempts limit.
- if (this.Logger.IsEnabled(LogLevel.Debug))
- {
- this.Logger.LogDebug("Maximum use ({MaximumUse}) reached; removing the tool.", chatExecutionSettings.ToolCallBehavior!.MaximumUseAttempts);
- }
- }
- else
- {
- // Regenerate the tool list as necessary. The invocation of the function(s) could have augmented
- // what functions are available in the kernel.
- chatExecutionSettings.ToolCallBehavior.ConfigureOptions(kernel, chatOptions);
- }
-
- // Having already sent tools and with tool call information in history, the service can become unhappy ("[] is too short - 'tools'")
- // if we don't send any tools in subsequent requests, even if we say not to use any.
- if (chatOptions.ToolChoice == ChatCompletionsToolChoice.None)
- {
- Debug.Assert(chatOptions.Tools.Count == 0);
- chatOptions.Tools.Add(s_nonInvocableFunctionTool);
- }
+ functionCallConfiguration = this.ConfigureFunctionCalling(requestIndex, kernel, chatExecutionSettings, chatOptions);
// Disable auto invocation if we've exceeded the allowed limit.
- if (requestIndex >= chatExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts)
+ if (requestIndex >= functionCallConfiguration?.MaximumAutoInvokeAttempts)
{
autoInvoke = false;
if (this.Logger.IsEnabled(LogLevel.Debug))
{
- this.Logger.LogDebug("Maximum auto-invoke ({MaximumAutoInvoke}) reached.", chatExecutionSettings.ToolCallBehavior!.MaximumAutoInvokeAttempts);
+ this.Logger.LogDebug("Maximum auto-invoke ({MaximumAutoInvoke}) reached.", functionCallConfiguration?.MaximumAutoInvokeAttempts);
}
}
}
@@ -1115,7 +1085,6 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(
break;
}
- executionSettings.ToolCallBehavior?.ConfigureOptions(kernel, options);
if (executionSettings.TokenSelectionBiases is not null)
{
foreach (var keyValue in executionSettings.TokenSelectionBiases)
@@ -1571,4 +1540,154 @@ await autoFunctionInvocationFilters[index].OnAutoFunctionInvocationAsync(context
await functionCallCallback(context).ConfigureAwait(false);
}
}
+
+ /// <summary>
+ /// Configures the function calling functionality based on the provided parameters.
+ /// </summary>
+ /// <param name="requestIndex">Request sequence index of the automatic function invocation process.</param>
+ /// <param name="kernel">The <see cref="Kernel"/> to be used for function calling.</param>
+ /// <param name="executionSettings">Execution settings for the completion API.</param>
+ /// <param name="chatOptions">The chat completion options from the Azure.AI.OpenAI package.</param>
+ private (bool? AllowAnyRequestedKernelFunction, int? MaximumAutoInvokeAttempts)? ConfigureFunctionCalling(int requestIndex, Kernel? kernel, OpenAIPromptExecutionSettings executionSettings, ChatCompletionsOptions chatOptions)
+ {
+ (bool? AllowAnyRequestedKernelFunction, int? MaximumAutoInvokeAttempts)? result = null;
+
+ // If neither behavior is specified, we don't need to do anything.
+ if (executionSettings.FunctionChoiceBehavior is null && executionSettings.ToolCallBehavior is null)
+ {
+ return result;
+ }
+
+ // If both behaviors are specified, we can't handle that.
+ if (executionSettings.FunctionChoiceBehavior is not null && executionSettings.ToolCallBehavior is not null)
+ {
+ throw new ArgumentException($"{nameof(executionSettings.ToolCallBehavior)} and {nameof(executionSettings.FunctionChoiceBehavior)} cannot be used together.");
+ }
+
+ // Set the tool choice to none. If we end up wanting to use tools, we'll set it to the desired value.
+ chatOptions.ToolChoice = ChatCompletionsToolChoice.None;
+ chatOptions.Tools.Clear();
+
+ // Handling new tool behavior represented by `PromptExecutionSettings.FunctionChoiceBehavior` property.
+ if (executionSettings.FunctionChoiceBehavior is { } functionChoiceBehavior)
+ {
+ result = this.ConfigureFunctionCalling(requestIndex, kernel, chatOptions, functionChoiceBehavior);
+ }
+ // Handling old-style tool call behavior represented by `OpenAIPromptExecutionSettings.ToolCallBehavior` property.
+ else if (executionSettings.ToolCallBehavior is { } toolCallBehavior)
+ {
+ result = this.ConfigureFunctionCalling(requestIndex, kernel, chatOptions, toolCallBehavior);
+ }
+
+ // Having already sent tools and with tool call information in history, the service can become unhappy ("Invalid 'tools': empty array. Expected an array with minimum length 1, but got an empty array instead.")
+ // if we don't send any tools in subsequent requests, even if we say not to use any.
+ // Similarly, if we say not to use any tool (ToolChoice = ChatCompletionsToolChoice.None) and don't provide any for the first request,
+ // the service fails with "'tool_choice' is only allowed when 'tools' are specified."
+ if (chatOptions.ToolChoice == ChatCompletionsToolChoice.None && chatOptions.Tools.Count == 0)
+ {
+ chatOptions.Tools.Add(s_nonInvocableFunctionTool);
+ }
+
+ return result;
+ }
+
+ private (bool? AllowAnyRequestedKernelFunction, int? MaximumAutoInvokeAttempts)? ConfigureFunctionCalling(int requestIndex, Kernel? kernel, ChatCompletionsOptions chatOptions, FunctionChoiceBehavior functionChoiceBehavior)
+ {
+ // Regenerate the tool list as necessary and get other call-behavior properties. The invocation of the function(s) could have augmented
+ // what functions are available in the kernel.
+ var config = functionChoiceBehavior.GetConfiguration(new() { Kernel = kernel });
+
+ (bool? AllowAnyRequestedKernelFunction, int? MaximumAutoInvokeAttempts) result = new()
+ {
+ AllowAnyRequestedKernelFunction = config.AllowAnyRequestedKernelFunction,
+ MaximumAutoInvokeAttempts = config.Options.AutoInvoke ? MaximumAutoInvokeAttempts : 0,
+ };
+
+ if (config.Choice == FunctionChoice.Required && requestIndex >= MaximumUseAttempts)
+ {
+ // Don't add any tools as we've reached the maximum use attempts limit.
+ if (this.Logger.IsEnabled(LogLevel.Debug))
+ {
+ this.Logger.LogDebug("Maximum use ({MaximumUse}) reached; removing the functions.", MaximumUseAttempts);
+ }
+
+ return result;
+ }
+
+ if (config.Choice == FunctionChoice.Auto)
+ {
+ if (config.Functions is { Count: > 0 } functions)
+ {
+ chatOptions.ToolChoice = ChatCompletionsToolChoice.Auto;
+
+ foreach (var function in functions)
+ {
+ var functionDefinition = function.Metadata.ToOpenAIFunction().ToFunctionDefinition();
+ chatOptions.Tools.Add(new ChatCompletionsFunctionToolDefinition(functionDefinition));
+ }
+ }
+
+ return result;
+ }
+
+ if (config.Choice == FunctionChoice.Required)
+ {
+ if (config.Functions is { Count: > 0 } functions)
+ {
+ if (functions.Count > 1)
+ {
+ throw new KernelException("Only one required function is allowed.");
+ }
+
+ var functionDefinition = functions[0].Metadata.ToOpenAIFunction().ToFunctionDefinition();
+
+ chatOptions.ToolChoice = new ChatCompletionsToolChoice(functionDefinition);
+ chatOptions.Tools.Add(new ChatCompletionsFunctionToolDefinition(functionDefinition));
+ }
+
+ return result;
+ }
+
+ if (config.Choice == FunctionChoice.None)
+ {
+ if (config.Functions is { Count: > 0 } functions)
+ {
+ chatOptions.ToolChoice = ChatCompletionsToolChoice.None;
+
+ foreach (var function in functions)
+ {
+ var functionDefinition = function.Metadata.ToOpenAIFunction().ToFunctionDefinition();
+ chatOptions.Tools.Add(new ChatCompletionsFunctionToolDefinition(functionDefinition));
+ }
+ }
+
+ return result;
+ }
+
+ throw new NotSupportedException($"Unsupported function choice '{config.Choice}'.");
+ }
+
+ private (bool? AllowAnyRequestedKernelFunction, int? MaximumAutoInvokeAttempts)? ConfigureFunctionCalling(int requestIndex, Kernel? kernel, ChatCompletionsOptions chatOptions, ToolCallBehavior toolCallBehavior)
+ {
+ if (requestIndex >= toolCallBehavior.MaximumUseAttempts)
+ {
+ // Don't add any tools as we've reached the maximum attempts limit.
+ if (this.Logger.IsEnabled(LogLevel.Debug))
+ {
+ this.Logger.LogDebug("Maximum use ({MaximumUse}) reached; removing the tools.", toolCallBehavior.MaximumUseAttempts);
+ }
+ }
+ else
+ {
+ // Regenerate the tool list as necessary. The invocation of the function(s) could have augmented
+ // what functions are available in the kernel.
+ toolCallBehavior.ConfigureOptions(kernel, chatOptions);
+ }
+
+ return new()
+ {
+ AllowAnyRequestedKernelFunction = toolCallBehavior.AllowAnyRequestedKernelFunction,
+ MaximumAutoInvokeAttempts = toolCallBehavior.MaximumAutoInvokeAttempts,
+ };
+ }
}
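
For callers, the upshot of the changes in this file is that FunctionChoiceBehavior and the legacy ToolCallBehavior are now mutually exclusive on a single settings instance (setting both throws ArgumentException). A minimal usage sketch, assuming an IChatCompletionService instance named chatService and a kernel with plugins already registered:

    var settings = new OpenAIPromptExecutionSettings
    {
        // One of Auto(), Required(), or None(); combining this with ToolCallBehavior throws.
        FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
    };
    var reply = await chatService.GetChatMessageContentAsync(chatHistory, settings, kernel);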
diff --git a/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs
index 36796c62f7b9..8707adde442c 100644
--- a/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs
+++ b/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs
@@ -343,6 +343,7 @@ public override PromptExecutionSettings Clone()
ResponseFormat = this.ResponseFormat,
TokenSelectionBiases = this.TokenSelectionBiases is not null ? new Dictionary<int, int>(this.TokenSelectionBiases) : null,
ToolCallBehavior = this.ToolCallBehavior,
+ FunctionChoiceBehavior = this.FunctionChoiceBehavior,
User = this.User,
ChatSystemPrompt = this.ChatSystemPrompt,
Logprobs = this.Logprobs,
@@ -382,6 +383,8 @@ public static OpenAIPromptExecutionSettings FromExecutionSettings(PromptExecutio
var openAIExecutionSettings = JsonSerializer.Deserialize<OpenAIPromptExecutionSettings>(json, JsonOptionsCache.ReadPermissive);
if (openAIExecutionSettings is not null)
{
+ // Restore the original function choice behavior, which loses its internal state (the list of functions) during the serialization/deserialization round-trip.
+ openAIExecutionSettings.FunctionChoiceBehavior = executionSettings.FunctionChoiceBehavior;
return openAIExecutionSettings;
}
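
The restoration above matters because FromExecutionSettings round-trips the settings through JSON, and the function list held by FunctionChoiceBehavior does not survive that trip. A sketch of the observable behavior, mirroring the unit test added further below:

    var original = new PromptExecutionSettings
    {
        FunctionChoiceBehavior = FunctionChoiceBehavior.None()
    };

    // The converted settings carry the original behavior instance rather than a lossy deserialized copy.
    OpenAIPromptExecutionSettings converted = OpenAIPromptExecutionSettings.FromExecutionSettings(original);
    Debug.Assert(ReferenceEquals(original.FunctionChoiceBehavior, converted.FunctionChoiceBehavior));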
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs
index 22be8458c2cc..625b9f30f228 100644
--- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs
+++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs
@@ -936,6 +936,105 @@ public async Task FunctionResultsCanBeProvidedToLLMAsManyResultsInOneChatMessage
Assert.Equal("2", assistantMessage2.GetProperty("tool_call_id").GetString());
}
+ [Fact]
+ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingAutoFunctionChoiceBehaviorAsync()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("TimePlugin", [
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Date"),
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Now")
+ ]);
+
+ var chatCompletion = new AzureOpenAIChatCompletionService("deployment", "https://endpoint", "api-key", "model-id", this._httpClient);
+
+ this._messageHandlerStub.ResponsesToReturn.Add(new HttpResponseMessage(HttpStatusCode.OK)
+ {
+ Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json"))
+ });
+
+ var executionSettings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+
+ // Act
+ await chatCompletion.GetChatMessageContentsAsync([], executionSettings, kernel);
+
+ // Assert
+ var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContents[0]!);
+ Assert.NotNull(actualRequestContent);
+
+ var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
+ Assert.Equal(2, optionsJson.GetProperty("tools").GetArrayLength());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tools")[0].GetProperty("function").GetProperty("name").GetString());
+ Assert.Equal("TimePlugin-Now", optionsJson.GetProperty("tools")[1].GetProperty("function").GetProperty("name").GetString());
+
+ Assert.Equal("auto", optionsJson.GetProperty("tool_choice").ToString());
+ }
+
+ [Fact]
+ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingRequiredFunctionChoiceBehaviorAsync()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("TimePlugin", [
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Date"),
+ ]);
+
+ var chatCompletion = new AzureOpenAIChatCompletionService("deployment", "https://endpoint", "api-key", "model-id", this._httpClient);
+
+ this._messageHandlerStub.ResponsesToReturn.Add(new HttpResponseMessage(HttpStatusCode.OK)
+ {
+ Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json"))
+ });
+
+ var executionSettings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Required() };
+
+ // Act
+ await chatCompletion.GetChatMessageContentsAsync([], executionSettings, kernel);
+
+ // Assert
+ var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContents[0]!);
+ Assert.NotNull(actualRequestContent);
+
+ var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
+ Assert.Equal(1, optionsJson.GetProperty("tools").GetArrayLength());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tools")[0].GetProperty("function").GetProperty("name").GetString());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tool_choice").GetProperty("function").GetProperty("name").ToString());
+ }
+
+ [Fact]
+ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingNoneFunctionChoiceBehaviorAsync()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("TimePlugin", [
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Date"),
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Now")
+ ]);
+
+ var chatCompletion = new AzureOpenAIChatCompletionService("deployment", "https://endpoint", "api-key", "model-id", this._httpClient);
+
+ this._messageHandlerStub.ResponsesToReturn.Add(new HttpResponseMessage(HttpStatusCode.OK)
+ {
+ Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json"))
+ });
+
+ var executionSettings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.None() };
+
+ // Act
+ await chatCompletion.GetChatMessageContentsAsync([], executionSettings, kernel);
+
+ // Assert
+ var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContents[0]!);
+ Assert.NotNull(actualRequestContent);
+
+ var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
+ Assert.Equal(2, optionsJson.GetProperty("tools").GetArrayLength());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tools")[0].GetProperty("function").GetProperty("name").GetString());
+ Assert.Equal("TimePlugin-Now", optionsJson.GetProperty("tools")[1].GetProperty("function").GetProperty("name").GetString());
+
+ Assert.Equal("none", optionsJson.GetProperty("tool_choice").ToString());
+ }
+
public void Dispose()
{
this._httpClient.Dispose();
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs
index 7d1c47388f91..3aa34d54af45 100644
--- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs
+++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/OpenAIChatCompletionServiceTests.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Globalization;
using System.IO;
+using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
@@ -27,7 +28,8 @@ public sealed class OpenAIChatCompletionServiceTests : IDisposable
{
private readonly HttpMessageHandlerStub _messageHandlerStub;
private readonly HttpClient _httpClient;
- private readonly OpenAIFunction _timepluginDate, _timepluginNow;
+ private readonly KernelPlugin _plugin;
+ private readonly KernelFunction _timepluginDate, _timepluginNow;
private readonly OpenAIPromptExecutionSettings _executionSettings;
private readonly Mock<ILoggerFactory> _mockLoggerFactory;
@@ -37,18 +39,21 @@ public OpenAIChatCompletionServiceTests()
this._httpClient = new HttpClient(this._messageHandlerStub, false);
this._mockLoggerFactory = new Mock<ILoggerFactory>();
- IList<KernelFunctionMetadata> functions = KernelPluginFactory.CreateFromFunctions("TimePlugin", new[]
+ this._plugin = KernelPluginFactory.CreateFromFunctions("TimePlugin", new[]
{
KernelFunctionFactory.CreateFromMethod((string? format = null) => DateTime.Now.Date.ToString(format, CultureInfo.InvariantCulture), "Date", "TimePlugin.Date"),
KernelFunctionFactory.CreateFromMethod((string? format = null) => DateTime.Now.ToString(format, CultureInfo.InvariantCulture), "Now", "TimePlugin.Now"),
- }).GetFunctionsMetadata();
+ });
- this._timepluginDate = functions[0].ToOpenAIFunction();
- this._timepluginNow = functions[1].ToOpenAIFunction();
+ this._timepluginDate = this._plugin.ElementAt(0);
+ this._timepluginNow = this._plugin.ElementAt(1);
this._executionSettings = new()
{
- ToolCallBehavior = ToolCallBehavior.EnableFunctions([this._timepluginDate, this._timepluginNow])
+ ToolCallBehavior = ToolCallBehavior.EnableFunctions([
+ this._timepluginDate.Metadata.ToOpenAIFunction(),
+ this._timepluginNow.Metadata.ToOpenAIFunction()
+ ])
};
}
@@ -161,7 +166,7 @@ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingNowAsync()
var chatCompletion = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient);
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{ Content = new StringContent(ChatCompletionResponse) };
- this._executionSettings.ToolCallBehavior = ToolCallBehavior.RequireFunction(this._timepluginNow);
+ this._executionSettings.ToolCallBehavior = ToolCallBehavior.RequireFunction(this._timepluginNow.Metadata.ToOpenAIFunction());
// Act
await chatCompletion.GetChatMessageContentsAsync([], this._executionSettings);
@@ -587,6 +592,100 @@ public async Task FunctionResultsCanBeProvidedToLLMAsManyResultsInOneChatMessage
Assert.Equal("2", assistantMessage2.GetProperty("tool_call_id").GetString());
}
+ [Fact]
+ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingAutoFunctionChoiceBehaviorAsync()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.Add(this._plugin);
+
+ var chatCompletion = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient);
+
+ this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
+ {
+ Content = new StringContent(ChatCompletionResponse)
+ };
+
+ var executionSettings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };
+
+ // Act
+ await chatCompletion.GetChatMessageContentsAsync([], executionSettings, kernel);
+
+ // Assert
+ var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
+ Assert.NotNull(actualRequestContent);
+
+ var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
+ Assert.Equal(2, optionsJson.GetProperty("tools").GetArrayLength());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tools")[0].GetProperty("function").GetProperty("name").GetString());
+ Assert.Equal("TimePlugin-Now", optionsJson.GetProperty("tools")[1].GetProperty("function").GetProperty("name").GetString());
+
+ Assert.Equal("auto", optionsJson.GetProperty("tool_choice").ToString());
+ }
+
+ [Fact]
+ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingRequiredFunctionChoiceBehaviorAsync()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("TimePlugin", [this._timepluginDate]);
+
+ var chatCompletion = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient);
+
+ this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
+ {
+ Content = new StringContent(ChatCompletionResponse)
+ };
+
+ var executionSettings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Required() };
+
+ // Act
+ await chatCompletion.GetChatMessageContentsAsync([], executionSettings, kernel);
+
+ // Assert
+ var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
+ Assert.NotNull(actualRequestContent);
+
+ var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
+ Assert.Equal(1, optionsJson.GetProperty("tools").GetArrayLength());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tools")[0].GetProperty("function").GetProperty("name").GetString());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tool_choice").GetProperty("function").GetProperty("name").ToString());
+ }
+
+ [Fact]
+ public async Task ItCreatesCorrectFunctionToolCallsWhenUsingNoneFunctionChoiceBehaviorAsync()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("TimePlugin", [
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Date"),
+ KernelFunctionFactory.CreateFromMethod(() => { }, "Now")
+ ]);
+
+ var chatCompletion = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient);
+
+ this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
+ {
+ Content = new StringContent(ChatCompletionResponse)
+ };
+
+ var executionSettings = new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.None() };
+
+ // Act
+ await chatCompletion.GetChatMessageContentsAsync([], executionSettings, kernel);
+
+ // Assert
+ var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!);
+ Assert.NotNull(actualRequestContent);
+
+ var optionsJson = JsonSerializer.Deserialize<JsonElement>(actualRequestContent);
+ Assert.Equal(2, optionsJson.GetProperty("tools").GetArrayLength());
+ Assert.Equal("TimePlugin-Date", optionsJson.GetProperty("tools")[0].GetProperty("function").GetProperty("name").GetString());
+ Assert.Equal("TimePlugin-Now", optionsJson.GetProperty("tools")[1].GetProperty("function").GetProperty("name").GetString());
+
+ Assert.Equal("none", optionsJson.GetProperty("tool_choice").ToString());
+ }
+
public void Dispose()
{
this._httpClient.Dispose();
diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs
index b64649230d96..c29d060b5117 100644
--- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs
+++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs
@@ -256,6 +256,22 @@ public void FromExecutionSettingsWithDataDoesNotIncludeEmptyStopSequences()
Assert.Null(executionSettingsWithData.StopSequences);
}
+ [Fact]
+ public void ItRestoresOriginalFunctionChoiceBehavior()
+ {
+ // Arrange
+ var functionChoiceBehavior = FunctionChoiceBehavior.None();
+
+ var originalExecutionSettings = new PromptExecutionSettings();
+ originalExecutionSettings.FunctionChoiceBehavior = functionChoiceBehavior;
+
+ // Act
+ var result = OpenAIPromptExecutionSettings.FromExecutionSettings(originalExecutionSettings);
+
+ // Assert
+ Assert.Equal(functionChoiceBehavior, result.FunctionChoiceBehavior);
+ }
+
private static void AssertExecutionSettings(OpenAIPromptExecutionSettings executionSettings)
{
Assert.NotNull(executionSettings);
diff --git a/dotnet/src/Functions/Functions.UnitTests/Markdown/Functions/KernelFunctionMarkdownTests.cs b/dotnet/src/Functions/Functions.UnitTests/Markdown/Functions/KernelFunctionMarkdownTests.cs
index a277284f3ccc..221752578bf6 100644
--- a/dotnet/src/Functions/Functions.UnitTests/Markdown/Functions/KernelFunctionMarkdownTests.cs
+++ b/dotnet/src/Functions/Functions.UnitTests/Markdown/Functions/KernelFunctionMarkdownTests.cs
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
+using System.Linq;
using Microsoft.SemanticKernel;
using Xunit;
@@ -18,9 +19,62 @@ public void ItShouldCreatePromptFunctionConfigFromMarkdown()
Assert.NotNull(model);
Assert.Equal("TellMeAbout", model.Name);
Assert.Equal("Hello AI, tell me about {{$input}}", model.Template);
- Assert.Equal(2, model.ExecutionSettings.Count);
+ Assert.Equal(3, model.ExecutionSettings.Count);
Assert.Equal("gpt4", model.ExecutionSettings["service1"].ModelId);
Assert.Equal("gpt3.5", model.ExecutionSettings["service2"].ModelId);
+ Assert.Equal("gpt3.5-turbo", model.ExecutionSettings["service3"].ModelId);
+ }
+
+ [Fact]
+ public void ItShouldInitializeFunctionChoiceBehaviorsFromMarkdown()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("p1", [KernelFunctionFactory.CreateFromMethod(() => { }, "f1")]);
+ kernel.Plugins.AddFromFunctions("p2", [KernelFunctionFactory.CreateFromMethod(() => { }, "f2")]);
+ kernel.Plugins.AddFromFunctions("p3", [KernelFunctionFactory.CreateFromMethod(() => { }, "f3")]);
+
+ // Act
+ var function = KernelFunctionMarkdown.CreateFromPromptMarkdown(Markdown, "TellMeAbout");
+
+ // Assert
+ Assert.NotNull(function);
+ Assert.NotEmpty(function.ExecutionSettings);
+
+ Assert.Equal(3, function.ExecutionSettings.Count);
+
+ // AutoFunctionCallChoice for service1
+ var service1ExecutionSettings = function.ExecutionSettings["service1"];
+ Assert.NotNull(service1ExecutionSettings?.FunctionChoiceBehavior);
+
+ var autoConfig = service1ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(autoConfig);
+ Assert.Equal(FunctionChoice.Auto, autoConfig.Choice);
+ Assert.NotNull(autoConfig.Functions);
+ Assert.Equal("p1", autoConfig.Functions.Single().PluginName);
+ Assert.Equal("f1", autoConfig.Functions.Single().Name);
+
+ // RequiredFunctionCallChoice for service2
+ var service2ExecutionSettings = function.ExecutionSettings["service2"];
+ Assert.NotNull(service2ExecutionSettings?.FunctionChoiceBehavior);
+
+ var requiredConfig = service2ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(requiredConfig);
+ Assert.Equal(FunctionChoice.Required, requiredConfig.Choice);
+ Assert.NotNull(requiredConfig.Functions);
+ Assert.Equal("p2", requiredConfig.Functions.Single().PluginName);
+ Assert.Equal("f2", requiredConfig.Functions.Single().Name);
+
+ // NoneFunctionCallChoice for service3
+ var service3ExecutionSettings = function.ExecutionSettings["service3"];
+ Assert.NotNull(service3ExecutionSettings?.FunctionChoiceBehavior);
+
+ var noneConfig = service3ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(noneConfig);
+ Assert.Equal(FunctionChoice.None, noneConfig.Choice);
+ Assert.NotNull(noneConfig.Functions);
+ Assert.Equal("p3", noneConfig.Functions.Single().PluginName);
+ Assert.Equal("f3", noneConfig.Functions.Single().Name);
}
[Fact]
@@ -47,7 +101,11 @@ These are AI execution settings
{
"service1" : {
"model_id": "gpt4",
- "temperature": 0.7
+ "temperature": 0.7,
+ "function_choice_behavior": {
+ "type": "auto",
+ "functions": ["p1.f1"]
+ }
}
}
```
@@ -56,7 +114,24 @@ These are more AI execution settings
{
"service2" : {
"model_id": "gpt3.5",
- "temperature": 0.8
+ "temperature": 0.8,
+ "function_choice_behavior": {
+ "type": "required",
+ "functions": ["p2.f2"]
+ }
+ }
+ }
+ ```
+ These are AI execution settings as well
+ ```sk.execution_settings
+ {
+ "service3" : {
+ "model_id": "gpt3.5-turbo",
+ "temperature": 0.8,
+ "function_choice_behavior": {
+ "type": "none",
+ "functions": ["p3.f3"]
+ }
}
}
```
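
With the three settings blocks above in place, creating the function from markdown surfaces the parsed behavior per service. A brief sketch reusing the test's Markdown constant:

    // Each service's execution settings carry the function_choice_behavior parsed from its block.
    KernelFunction function = KernelFunctionMarkdown.CreateFromPromptMarkdown(Markdown, "TellMeAbout");
    var autoBehavior = function.ExecutionSettings?["service1"].FunctionChoiceBehavior; // auto, scoped to p1.f1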
diff --git a/dotnet/src/Functions/Functions.UnitTests/Yaml/Functions/KernelFunctionYamlTests.cs b/dotnet/src/Functions/Functions.UnitTests/Yaml/Functions/KernelFunctionYamlTests.cs
index 30bce2a3fac2..d898822893ca 100644
--- a/dotnet/src/Functions/Functions.UnitTests/Yaml/Functions/KernelFunctionYamlTests.cs
+++ b/dotnet/src/Functions/Functions.UnitTests/Yaml/Functions/KernelFunctionYamlTests.cs
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
+using System.Linq;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Xunit;
@@ -68,7 +69,7 @@ public void ItShouldSupportCreatingOpenAIExecutionSettings()
// Arrange
var deserializer = new DeserializerBuilder()
.WithNamingConvention(UnderscoredNamingConvention.Instance)
- .WithNodeDeserializer(new PromptExecutionSettingsNodeDeserializer())
+ .WithTypeConverter(new PromptExecutionSettingsTypeConverter())
.Build();
var promptFunctionModel = deserializer.Deserialize<PromptTemplateConfig>(this._yaml);
@@ -82,6 +83,55 @@ public void ItShouldSupportCreatingOpenAIExecutionSettings()
Assert.Equal(0.0, executionSettings.TopP);
}
+ [Fact]
+ public void ItShouldDeserializeFunctionChoiceBehaviors()
+ {
+ // Act
+ var promptTemplateConfig = KernelFunctionYaml.ToPromptTemplateConfig(this._yaml);
+
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("p1", [KernelFunctionFactory.CreateFromMethod(() => { }, "f1")]);
+ kernel.Plugins.AddFromFunctions("p2", [KernelFunctionFactory.CreateFromMethod(() => { }, "f2")]);
+ kernel.Plugins.AddFromFunctions("p3", [KernelFunctionFactory.CreateFromMethod(() => { }, "f3")]);
+
+ // Assert
+ Assert.NotNull(promptTemplateConfig?.ExecutionSettings);
+ Assert.Equal(3, promptTemplateConfig.ExecutionSettings.Count);
+
+ // Service with auto function choice behavior
+ var service1ExecutionSettings = promptTemplateConfig.ExecutionSettings["service1"];
+ Assert.NotNull(service1ExecutionSettings?.FunctionChoiceBehavior);
+
+ var autoConfig = service1ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(autoConfig);
+ Assert.Equal(FunctionChoice.Auto, autoConfig.Choice);
+ Assert.NotNull(autoConfig.Functions);
+ Assert.Equal("p1", autoConfig.Functions.Single().PluginName);
+ Assert.Equal("f1", autoConfig.Functions.Single().Name);
+
+ // Service with required function choice behavior
+ var service2ExecutionSettings = promptTemplateConfig.ExecutionSettings["service2"];
+ Assert.NotNull(service2ExecutionSettings?.FunctionChoiceBehavior);
+
+ var requiredConfig = service2ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(requiredConfig);
+ Assert.Equal(FunctionChoice.Required, requiredConfig.Choice);
+ Assert.NotNull(requiredConfig.Functions);
+ Assert.Equal("p2", requiredConfig.Functions.Single().PluginName);
+ Assert.Equal("f2", requiredConfig.Functions.Single().Name);
+
+ // Service with none function choice behavior
+ var service3ExecutionSettings = promptTemplateConfig.ExecutionSettings["service3"];
+ Assert.NotNull(service3ExecutionSettings?.FunctionChoiceBehavior);
+
+ var noneConfig = service3ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(noneConfig);
+ Assert.Equal(FunctionChoice.None, noneConfig.Choice);
+ Assert.NotNull(noneConfig.Functions);
+ Assert.Equal("p3", noneConfig.Functions.Single().PluginName);
+ Assert.Equal("f3", noneConfig.Functions.Single().Name);
+ }
+
[Fact]
public void ItShouldCreateFunctionWithDefaultValueOfStringType()
{
@@ -157,6 +207,10 @@ string CreateYaml(object defaultValue)
frequency_penalty: 0.0
max_tokens: 256
stop_sequences: []
+ function_choice_behavior:
+ type: auto
+ functions:
+ - p1.f1
service2:
model_id: gpt-3.5
temperature: 1.0
@@ -165,6 +219,22 @@ string CreateYaml(object defaultValue)
frequency_penalty: 0.0
max_tokens: 256
stop_sequences: [ "foo", "bar", "baz" ]
+ function_choice_behavior:
+ type: required
+ functions:
+ - p2.f2
+ service3:
+ model_id: gpt-3.5
+ temperature: 1.0
+ top_p: 0.0
+ presence_penalty: 0.0
+ frequency_penalty: 0.0
+ max_tokens: 256
+ stop_sequences: [ "foo", "bar", "baz" ]
+ function_choice_behavior:
+ type: none
+ functions:
+ - p3.f3
""";
private readonly string _yamlWithCustomSettings = """
diff --git a/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsNodeDeserializerTests.cs b/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsNodeDeserializerTests.cs
deleted file mode 100644
index 140de66fdaa8..000000000000
--- a/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsNodeDeserializerTests.cs
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using Microsoft.SemanticKernel;
-using Xunit;
-using YamlDotNet.Serialization;
-using YamlDotNet.Serialization.NamingConventions;
-
-namespace SemanticKernel.Functions.UnitTests.Yaml;
-
-/// <summary>
-/// Tests for <see cref="PromptExecutionSettingsNodeDeserializer"/>.
-/// </summary>
-public sealed class PromptExecutionSettingsNodeDeserializerTests
-{
- [Fact]
- public void ItShouldCreatePromptFunctionFromYamlWithCustomModelSettings()
- {
- // Arrange
- var deserializer = new DeserializerBuilder()
- .WithNamingConvention(UnderscoredNamingConvention.Instance)
- .WithNodeDeserializer(new PromptExecutionSettingsNodeDeserializer())
- .Build();
-
- // Act
- var semanticFunctionConfig = deserializer.Deserialize<PromptTemplateConfig>(this._yaml);
-
- // Assert
- Assert.NotNull(semanticFunctionConfig);
- Assert.Equal("SayHello", semanticFunctionConfig.Name);
- Assert.Equal("Say hello to the specified person using the specified language", semanticFunctionConfig.Description);
- Assert.Equal(2, semanticFunctionConfig.InputVariables.Count);
- Assert.Equal("language", semanticFunctionConfig.InputVariables[1].Name);
- Assert.Equal(2, semanticFunctionConfig.ExecutionSettings.Count);
- Assert.Equal("gpt-4", semanticFunctionConfig.ExecutionSettings["service1"].ModelId);
- Assert.Equal("gpt-3.5", semanticFunctionConfig.ExecutionSettings["service2"].ModelId);
- }
-
- private readonly string _yaml = """
- template_format: semantic-kernel
- template: Say hello world to {{$name}} in {{$language}}
- description: Say hello to the specified person using the specified language
- name: SayHello
- input_variables:
- - name: name
- description: The name of the person to greet
- default: John
- - name: language
- description: The language to generate the greeting in
- default: English
- execution_settings:
- service1:
- model_id: gpt-4
- temperature: 1.0
- top_p: 0.0
- presence_penalty: 0.0
- frequency_penalty: 0.0
- max_tokens: 256
- stop_sequences: []
- service2:
- model_id: gpt-3.5
- temperature: 1.0
- top_p: 0.0
- presence_penalty: 0.0
- frequency_penalty: 0.0
- max_tokens: 256
- stop_sequences: [ "foo", "bar", "baz" ]
- """;
-}
diff --git a/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs b/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs
new file mode 100644
index 000000000000..2ef79ef0e850
--- /dev/null
+++ b/dotnet/src/Functions/Functions.UnitTests/Yaml/PromptExecutionSettingsTypeConverterTests.cs
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Linq;
+using Microsoft.SemanticKernel;
+using Xunit;
+using YamlDotNet.Serialization;
+using YamlDotNet.Serialization.NamingConventions;
+
+namespace SemanticKernel.Functions.UnitTests.Yaml;
+
+/// <summary>
+/// Tests for <see cref="PromptExecutionSettingsTypeConverter"/>.
+/// </summary>
+public sealed class PromptExecutionSettingsTypeConverterTests
+{
+ [Fact]
+ public void ItShouldCreatePromptFunctionFromYamlWithCustomModelSettings()
+ {
+ // Arrange
+ var deserializer = new DeserializerBuilder()
+ .WithNamingConvention(UnderscoredNamingConvention.Instance)
+ .WithTypeConverter(new PromptExecutionSettingsTypeConverter())
+ .Build();
+
+ // Act
+ var semanticFunctionConfig = deserializer.Deserialize<PromptTemplateConfig>(this._yaml);
+
+ // Assert
+ Assert.NotNull(semanticFunctionConfig);
+ Assert.Equal("SayHello", semanticFunctionConfig.Name);
+ Assert.Equal("Say hello to the specified person using the specified language", semanticFunctionConfig.Description);
+ Assert.Equal(2, semanticFunctionConfig.InputVariables.Count);
+ Assert.Equal("language", semanticFunctionConfig.InputVariables[1].Name);
+ Assert.Equal(3, semanticFunctionConfig.ExecutionSettings.Count);
+ Assert.Equal("gpt-4", semanticFunctionConfig.ExecutionSettings["service1"].ModelId);
+ Assert.Equal("gpt-3.5", semanticFunctionConfig.ExecutionSettings["service2"].ModelId);
+ Assert.Equal("gpt-3.5-turbo", semanticFunctionConfig.ExecutionSettings["service3"].ModelId);
+ }
+
+ [Fact]
+ public void ItShouldDeserializeFunctionChoiceBehaviors()
+ {
+ // Arrange
+ var kernel = new Kernel();
+ kernel.Plugins.AddFromFunctions("p1", [KernelFunctionFactory.CreateFromMethod(() => { }, "f1")]);
+ kernel.Plugins.AddFromFunctions("p2", [KernelFunctionFactory.CreateFromMethod(() => { }, "f2")]);
+ kernel.Plugins.AddFromFunctions("p3", [KernelFunctionFactory.CreateFromMethod(() => { }, "f3")]);
+
+ // Act
+ var promptTemplateConfig = KernelFunctionYaml.ToPromptTemplateConfig(this._yaml);
+
+ // Assert
+ Assert.NotNull(promptTemplateConfig?.ExecutionSettings);
+ Assert.Equal(3, promptTemplateConfig.ExecutionSettings.Count);
+
+ // Service with auto function choice behavior
+ var service1ExecutionSettings = promptTemplateConfig.ExecutionSettings["service1"];
+ Assert.NotNull(service1ExecutionSettings?.FunctionChoiceBehavior);
+
+ var autoConfig = service1ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(autoConfig);
+ Assert.Equal(FunctionChoice.Auto, autoConfig.Choice);
+ Assert.NotNull(autoConfig.Functions);
+ Assert.Equal("p1", autoConfig.Functions.Single().PluginName);
+ Assert.Equal("f1", autoConfig.Functions.Single().Name);
+
+ // Service with required function choice behavior
+ var service2ExecutionSettings = promptTemplateConfig.ExecutionSettings["service2"];
+ Assert.NotNull(service2ExecutionSettings?.FunctionChoiceBehavior);
+
+ var requiredConfig = service2ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(requiredConfig);
+ Assert.Equal(FunctionChoice.Required, requiredConfig.Choice);
+ Assert.NotNull(requiredConfig.Functions);
+ Assert.Equal("p2", requiredConfig.Functions.Single().PluginName);
+ Assert.Equal("f2", requiredConfig.Functions.Single().Name);
+
+ // Service with none function choice behavior
+ var service3ExecutionSettings = promptTemplateConfig.ExecutionSettings["service3"];
+ Assert.NotNull(service3ExecutionSettings?.FunctionChoiceBehavior);
+
+ var noneConfig = service3ExecutionSettings.FunctionChoiceBehavior.GetConfiguration(new FunctionChoiceBehaviorContext() { Kernel = kernel });
+ Assert.NotNull(noneConfig);
+ Assert.Equal(FunctionChoice.None, noneConfig.Choice);
+ Assert.NotNull(noneConfig.Functions);
+ Assert.Equal("p3", noneConfig.Functions.Single().PluginName);
+ Assert.Equal("f3", noneConfig.Functions.Single().Name);
+ }
+
+ private readonly string _yaml = """
+ template_format: semantic-kernel
+ template: Say hello world to {{$name}} in {{$language}}
+ description: Say hello to the specified person using the specified language
+ name: SayHello
+ input_variables:
+ - name: name
+ description: The name of the person to greet
+ default: John
+ - name: language
+ description: The language to generate the greeting in
+ default: English
+ execution_settings:
+ service1:
+ model_id: gpt-4
+ temperature: 1.0
+ top_p: 0.0
+ presence_penalty: 0.0
+ frequency_penalty: 0.0
+ max_tokens: 256
+ stop_sequences: []
+ function_choice_behavior:
+ type: auto
+ functions:
+ - p1.f1
+ service2:
+ model_id: gpt-3.5
+ temperature: 1.0
+ top_p: 0.0
+ presence_penalty: 0.0
+ frequency_penalty: 0.0
+ max_tokens: 256
+ stop_sequences: [ "foo", "bar", "baz" ]
+ function_choice_behavior:
+ type: required
+ functions:
+ - p2.f2
+ service3:
+ model_id: gpt-3.5-turbo
+ temperature: 1.0
+ top_p: 0.0
+ presence_penalty: 0.0
+ frequency_penalty: 0.0
+ max_tokens: 256
+ stop_sequences: [ "foo", "bar", "baz" ]
+ function_choice_behavior:
+ type: none
+ functions:
+ - p3.f3
+ """;
+}
diff --git a/dotnet/src/Functions/Functions.Yaml/KernelFunctionYaml.cs b/dotnet/src/Functions/Functions.Yaml/KernelFunctionYaml.cs
index ec2a26fc2b61..863d991bb207 100644
--- a/dotnet/src/Functions/Functions.Yaml/KernelFunctionYaml.cs
+++ b/dotnet/src/Functions/Functions.Yaml/KernelFunctionYaml.cs
@@ -57,7 +57,7 @@ public static PromptTemplateConfig ToPromptTemplateConfig(string text)
{
var deserializer = new DeserializerBuilder()
.WithNamingConvention(UnderscoredNamingConvention.Instance)
- .WithNodeDeserializer(new PromptExecutionSettingsNodeDeserializer())
+ .WithTypeConverter(new PromptExecutionSettingsTypeConverter())
.Build();
return deserializer.Deserialize<PromptTemplateConfig>(text);
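
With the type converter registered here, YAML prompts can declare function choice behavior inline. A minimal hedged sketch with placeholder plugin/function names:

    // The converter lets ToPromptTemplateConfig parse function_choice_behavior from YAML execution settings.
    var config = KernelFunctionYaml.ToPromptTemplateConfig("""
        template: Tell me about {{$input}}
        execution_settings:
          default:
            function_choice_behavior:
              type: auto
              functions:
              - MyPlugin.MyFunction
        """);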
diff --git a/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsNodeDeserializer.cs b/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsNodeDeserializer.cs
deleted file mode 100644
index 5bd7b839b068..000000000000
--- a/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsNodeDeserializer.cs
+++ /dev/null
@@ -1,43 +0,0 @@
-// Copyright (c) Microsoft. All rights reserved.
-
-using System;
-using System.Collections.Generic;
-using YamlDotNet.Core;
-using YamlDotNet.Serialization;
-
-namespace Microsoft.SemanticKernel;
-
-/// <summary>
-/// Deserializer for <see cref="PromptExecutionSettings"/>.
-/// </summary>
-internal sealed class PromptExecutionSettingsNodeDeserializer : INodeDeserializer
-{
- /// <inheritdoc/>
- public bool Deserialize(IParser reader, Type expectedType, Func<IParser, Type, object?> nestedObjectDeserializer, out object? value)
- {
- if (expectedType != typeof(PromptExecutionSettings))
- {
- value = null;
- return false;
- }
-
- var dictionary = nestedObjectDeserializer.Invoke(reader, typeof(Dictionary<string, object>));
- var modelSettings = new PromptExecutionSettings();
- foreach (var kv in (Dictionary<string, object>)dictionary!)
- {
- switch (kv.Key)
- {
- case "model_id":
- modelSettings.ModelId = (string)kv.Value;
- break;
-
- default:
- (modelSettings.ExtensionData ??= new Dictionary<string, object>()).Add(kv.Key, kv.Value);
- break;
- }
- }
-
- value = modelSettings;
- return true;
- }
-}
diff --git a/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs b/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs
new file mode 100644
index 000000000000..3f128806c145
--- /dev/null
+++ b/dotnet/src/Functions/Functions.Yaml/PromptExecutionSettingsTypeConverter.cs
@@ -0,0 +1,98 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.Json.Serialization;
+using YamlDotNet.Core;
+using YamlDotNet.Core.Events;
+using YamlDotNet.Serialization;
+using YamlDotNet.Serialization.BufferedDeserialization;
+using YamlDotNet.Serialization.NamingConventions;
+
+namespace Microsoft.SemanticKernel;
+
+/// <summary>
+/// Allows custom deserialization for <see cref="PromptExecutionSettings"/> from YAML prompts.
+/// </summary>
+internal sealed class PromptExecutionSettingsTypeConverter : IYamlTypeConverter
+{
+ private static IDeserializer? s_deserializer;
+
+ /// <inheritdoc/>
+ public bool Accepts(Type type)
+ {
+ return type == typeof(PromptExecutionSettings);
+ }
+
+ /// <inheritdoc/>
+ public object? ReadYaml(IParser parser, Type type)
+ {
+ s_deserializer ??= new DeserializerBuilder()
+ .WithNamingConvention(UnderscoredNamingConvention.Instance)
+ .IgnoreUnmatchedProperties() // Required to ignore the 'type' property used for type discrimination. Otherwise, a "Property 'type' not found on type '{type.FullName}'" exception is thrown.
+ .WithTypeDiscriminatingNodeDeserializer(ConfigureTypeDiscriminatingNodeDeserializer)
+ .Build();
+
+ parser.MoveNext(); // Move to the first property
+
+ var executionSettings = new PromptExecutionSettings();
+ while (parser.Current is not MappingEnd)
+ {
+ var propertyName = parser.Consume<Scalar>().Value;
+ switch (propertyName)
+ {
+ case "model_id":
+ executionSettings.ModelId = s_deserializer.Deserialize<string>(parser);
+ break;
+ case "function_choice_behavior":
+#pragma warning disable SKEXP0001
+ executionSettings.FunctionChoiceBehavior = s_deserializer.Deserialize<FunctionChoiceBehavior>(parser);
+#pragma warning restore SKEXP0001
+ break;
+ default:
+ (executionSettings.ExtensionData ??= new Dictionary<string, object>()).Add(propertyName, s_deserializer.Deserialize<object>(parser));