diff --git a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs index 47da5614adf2..e3a3d7574545 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs @@ -1036,7 +1036,7 @@ private static CompletionsOptions CreateCompletionsOptions(string text, OpenAIPr Echo = false, ChoicesPerPrompt = executionSettings.ResultsPerPrompt, GenerationSampleCount = executionSettings.ResultsPerPrompt, - LogProbabilityCount = null, + LogProbabilityCount = executionSettings.TopLogprobs, User = executionSettings.User, DeploymentName = deploymentOrModelName }; @@ -1088,7 +1088,9 @@ private ChatCompletionsOptions CreateChatCompletionsOptions( ChoiceCount = executionSettings.ResultsPerPrompt, DeploymentName = deploymentOrModelName, Seed = executionSettings.Seed, - User = executionSettings.User + User = executionSettings.User, + LogProbabilitiesPerToken = executionSettings.TopLogprobs, + EnableLogProbabilities = executionSettings.Logprobs }; switch (executionSettings.ResponseFormat) diff --git a/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs index f88cb18b7950..b4097b7020da 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs @@ -254,6 +254,39 @@ public string? User } } + /// + /// Whether to return log probabilities of the output tokens or not. + /// If true, returns the log probabilities of each output token returned in the `content` of `message`. + /// + [Experimental("SKEXP0010")] + [JsonPropertyName("logprobs")] + public bool? Logprobs + { + get => this._logprobs; + + set + { + this.ThrowIfFrozen(); + this._logprobs = value; + } + } + + /// + /// An integer specifying the number of most likely tokens to return at each token position, each with an associated log probability. + /// + [Experimental("SKEXP0010")] + [JsonPropertyName("top_logprobs")] + public int? TopLogprobs + { + get => this._topLogprobs; + + set + { + this.ThrowIfFrozen(); + this._topLogprobs = value; + } + } + /// public override void Freeze() { @@ -294,7 +327,9 @@ public override PromptExecutionSettings Clone() TokenSelectionBiases = this.TokenSelectionBiases is not null ? new Dictionary(this.TokenSelectionBiases) : null, ToolCallBehavior = this.ToolCallBehavior, User = this.User, - ChatSystemPrompt = this.ChatSystemPrompt + ChatSystemPrompt = this.ChatSystemPrompt, + Logprobs = this.Logprobs, + TopLogprobs = this.TopLogprobs }; } @@ -370,6 +405,8 @@ public static OpenAIPromptExecutionSettings FromExecutionSettingsWithData(Prompt private ToolCallBehavior? _toolCallBehavior; private string? _user; private string? _chatSystemPrompt; + private bool? _logprobs; + private int? _topLogprobs; #endregion } diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs index c8d6c0de5f40..159fcd7d852c 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs @@ -161,7 +161,9 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync() ResultsPerPrompt = 5, Seed = 567, TokenSelectionBiases = new Dictionary { { 2, 3 } }, - StopSequences = ["stop_sequence"] + StopSequences = ["stop_sequence"], + Logprobs = true, + TopLogprobs = 5 }; var chatHistory = new ChatHistory(); @@ -218,6 +220,8 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync() Assert.Equal(567, content.GetProperty("seed").GetInt32()); Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32()); Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString()); + Assert.True(content.GetProperty("logprobs").GetBoolean()); + Assert.Equal(5, content.GetProperty("top_logprobs").GetInt32()); } [Theory] diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs index 6def578e8821..c951f821b348 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs @@ -30,6 +30,8 @@ public void ItCreatesOpenAIExecutionSettingsWithCorrectDefaults() Assert.Equal(1, executionSettings.ResultsPerPrompt); Assert.Null(executionSettings.StopSequences); Assert.Null(executionSettings.TokenSelectionBiases); + Assert.Null(executionSettings.TopLogprobs); + Assert.Null(executionSettings.Logprobs); Assert.Equal(128, executionSettings.MaxTokens); } @@ -47,6 +49,8 @@ public void ItUsesExistingOpenAIExecutionSettings() StopSequences = new string[] { "foo", "bar" }, ChatSystemPrompt = "chat system prompt", MaxTokens = 128, + Logprobs = true, + TopLogprobs = 5, TokenSelectionBiases = new Dictionary() { { 1, 2 }, { 3, 4 } }, }; @@ -97,6 +101,8 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase() { "max_tokens", 128 }, { "token_selection_biases", new Dictionary() { { 1, 2 }, { 3, 4 } } }, { "seed", 123456 }, + { "logprobs", true }, + { "top_logprobs", 5 }, } }; @@ -105,7 +111,6 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase() // Assert AssertExecutionSettings(executionSettings); - Assert.Equal(executionSettings.Seed, 123456); } [Fact] @@ -124,7 +129,10 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings() { "stop_sequences", new [] { "foo", "bar" } }, { "chat_system_prompt", "chat system prompt" }, { "max_tokens", "128" }, - { "token_selection_biases", new Dictionary() { { "1", "2" }, { "3", "4" } } } + { "token_selection_biases", new Dictionary() { { "1", "2" }, { "3", "4" } } }, + { "seed", 123456 }, + { "logprobs", true }, + { "top_logprobs", 5 } } }; @@ -149,7 +157,10 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase() "stop_sequences": [ "foo", "bar" ], "chat_system_prompt": "chat system prompt", "token_selection_biases": { "1": 2, "3": 4 }, - "max_tokens": 128 + "max_tokens": 128, + "seed": 123456, + "logprobs": true, + "top_logprobs": 5 } """; var actualSettings = JsonSerializer.Deserialize(json); @@ -255,5 +266,8 @@ private static void AssertExecutionSettings(OpenAIPromptExecutionSettings execut Assert.Equal("chat system prompt", executionSettings.ChatSystemPrompt); Assert.Equal(new Dictionary() { { 1, 2 }, { 3, 4 } }, executionSettings.TokenSelectionBiases); Assert.Equal(128, executionSettings.MaxTokens); + Assert.Equal(123456, executionSettings.Seed); + Assert.Equal(true, executionSettings.Logprobs); + Assert.Equal(5, executionSettings.TopLogprobs); } } diff --git a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TextGeneration/AzureOpenAITextGenerationServiceTests.cs b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TextGeneration/AzureOpenAITextGenerationServiceTests.cs index 87f5526d5f83..d20bb502e23d 100644 --- a/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TextGeneration/AzureOpenAITextGenerationServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TextGeneration/AzureOpenAITextGenerationServiceTests.cs @@ -126,7 +126,8 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync() PresencePenalty = 1.2, ResultsPerPrompt = 5, TokenSelectionBiases = new Dictionary { { 2, 3 } }, - StopSequences = ["stop_sequence"] + StopSequences = ["stop_sequence"], + TopLogprobs = 5 }; this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK) @@ -154,6 +155,7 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync() Assert.Equal(5, content.GetProperty("best_of").GetInt32()); Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32()); Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString()); + Assert.Equal(5, content.GetProperty("logprobs").GetInt32()); } [Fact] diff --git a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs index 6b07e9b7b7ba..a2285a1c4dd5 100644 --- a/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs +++ b/dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs @@ -9,6 +9,7 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using Azure.AI.OpenAI; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Http.Resilience; @@ -504,6 +505,38 @@ public async Task SemanticKernelVersionHeaderIsSentAsync() Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values)); } + [Theory(Skip = "This test is for manual verification.")] + [InlineData(null, null)] + [InlineData(false, null)] + [InlineData(true, 2)] + [InlineData(true, 5)] + public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs) + { + // Arrange + var settings = new OpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs }; + + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + this.ConfigureAzureOpenAIChatAsText(builder); + Kernel target = builder.Build(); + + // Act + var result = await target.InvokePromptAsync("Hi, can you help me today?", new(settings)); + + var logProbabilityInfo = result.Metadata?["LogProbabilityInfo"] as ChatChoiceLogProbabilityInfo; + + // Assert + if (logprobs is true) + { + Assert.NotNull(logProbabilityInfo); + Assert.Equal(topLogprobs, logProbabilityInfo.TokenLogProbabilityResults[0].TopLogProbabilityEntries.Count); + } + else + { + Assert.Null(logProbabilityInfo); + } + } + #region internals private readonly XunitLogger _logger = new(output);