.Net: Added logprobs property to OpenAIPromptExecutionSettings #6300

Merged
7 commits, merged on May 20, 2024

@@ -1036,7 +1036,7 @@ private static CompletionsOptions CreateCompletionsOptions(string text, OpenAIPr
Echo = false,
ChoicesPerPrompt = executionSettings.ResultsPerPrompt,
GenerationSampleCount = executionSettings.ResultsPerPrompt,
-LogProbabilityCount = null,
+LogProbabilityCount = executionSettings.TopLogprobs,
User = executionSettings.User,
DeploymentName = deploymentOrModelName
};
@@ -1081,7 +1081,9 @@ private static ChatCompletionsOptions CreateChatCompletionsOptions(
ChoiceCount = executionSettings.ResultsPerPrompt,
DeploymentName = deploymentOrModelName,
Seed = executionSettings.Seed,
-User = executionSettings.User
+User = executionSettings.User,
+LogProbabilitiesPerToken = executionSettings.TopLogprobs,
+EnableLogProbabilities = executionSettings.Logprobs
};

switch (executionSettings.ResponseFormat)
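The two request builders above surface the new settings differently: chat completions sends a boolean logprobs flag plus an integer top_logprobs count, while the legacy text-completions path maps TopLogprobs onto the API's single integer logprobs field. A sketch of the resulting request-body fields only, with values taken from the unit tests below and the rest of each payload omitted:

Chat completions, with Logprobs = true and TopLogprobs = 5:

{ "logprobs": true, "top_logprobs": 5 }

Text completions, with TopLogprobs = 5:

{ "logprobs": 5 }
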
@@ -254,6 +254,39 @@ public string? User
}
}

/// <summary>
/// Whether to return log probabilities of the output tokens or not.
/// If true, returns the log probabilities of each output token returned in the `content` of `message`.
/// </summary>
[Experimental("SKEXP0010")]
[JsonPropertyName("logprobs")]
public bool? Logprobs
{
get => this._logprobs;

set
{
this.ThrowIfFrozen();
this._logprobs = value;
}
}

/// <summary>
/// An integer specifying the number of most likely tokens to return at each token position, each with an associated log probability.
/// </summary>
[Experimental("SKEXP0010")]
[JsonPropertyName("top_logprobs")]
public int? TopLogprobs
{
get => this._topLogprobs;

set
{
this.ThrowIfFrozen();
this._topLogprobs = value;
}
}

/// <inheritdoc/>
public override void Freeze()
{
@@ -294,7 +327,9 @@ public override PromptExecutionSettings Clone()
TokenSelectionBiases = this.TokenSelectionBiases is not null ? new Dictionary<int, int>(this.TokenSelectionBiases) : null,
ToolCallBehavior = this.ToolCallBehavior,
User = this.User,
-ChatSystemPrompt = this.ChatSystemPrompt
+ChatSystemPrompt = this.ChatSystemPrompt,
+Logprobs = this.Logprobs,
+TopLogprobs = this.TopLogprobs
};
}

@@ -370,6 +405,8 @@ public static OpenAIPromptExecutionSettings FromExecutionSettingsWithData(Prompt
private ToolCallBehavior? _toolCallBehavior;
private string? _user;
private string? _chatSystemPrompt;
private bool? _logprobs;
private int? _topLogprobs;

#endregion
}
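
Taken together, the new properties ride the normal execution-settings flow. A minimal usage sketch follows; the kernel wiring is illustrative, and only Logprobs, TopLogprobs, and the "LogProbabilityInfo" metadata key come from this change:

// 'kernel' is any configured Kernel instance (wiring omitted).
var settings = new OpenAIPromptExecutionSettings
{
    Logprobs = true, // ask the service to return log probabilities
    TopLogprobs = 5  // and the five most likely alternatives per position
};

var result = await kernel.InvokePromptAsync("Hi, can you help me today?", new(settings));

// The connector exposes the Azure SDK payload through result metadata.
var logProbabilityInfo = result.Metadata?["LogProbabilityInfo"] as ChatChoiceLogProbabilityInfo;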
@@ -161,7 +161,9 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
ResultsPerPrompt = 5,
Seed = 567,
TokenSelectionBiases = new Dictionary<int, int> { { 2, 3 } },
StopSequences = ["stop_sequence"]
StopSequences = ["stop_sequence"],
Logprobs = true,
TopLogprobs = 5
};

var chatHistory = new ChatHistory();
@@ -218,6 +220,8 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
Assert.Equal(567, content.GetProperty("seed").GetInt32());
Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32());
Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString());
Assert.True(content.GetProperty("logprobs").GetBoolean());
Assert.Equal(5, content.GetProperty("top_logprobs").GetInt32());
}

[Theory]
@@ -30,6 +30,8 @@ public void ItCreatesOpenAIExecutionSettingsWithCorrectDefaults()
Assert.Equal(1, executionSettings.ResultsPerPrompt);
Assert.Null(executionSettings.StopSequences);
Assert.Null(executionSettings.TokenSelectionBiases);
Assert.Null(executionSettings.TopLogprobs);
Assert.Null(executionSettings.Logprobs);
Assert.Equal(128, executionSettings.MaxTokens);
}

@@ -47,6 +49,8 @@ public void ItUsesExistingOpenAIExecutionSettings()
StopSequences = new string[] { "foo", "bar" },
ChatSystemPrompt = "chat system prompt",
MaxTokens = 128,
Logprobs = true,
TopLogprobs = 5,
TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
};

@@ -97,6 +101,8 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
{ "max_tokens", 128 },
{ "token_selection_biases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 },
}
};

@@ -105,7 +111,6 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()

// Assert
AssertExecutionSettings(executionSettings);
-Assert.Equal(executionSettings.Seed, 123456);
}

[Fact]
@@ -124,7 +129,10 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings()
{ "stop_sequences", new [] { "foo", "bar" } },
{ "chat_system_prompt", "chat system prompt" },
{ "max_tokens", "128" },
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } }
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } },
{ "seed", 123456 },
{ "logprobs", true },
{ "top_logprobs", 5 }
}
};

@@ -149,7 +157,10 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase()
"stop_sequences": [ "foo", "bar" ],
"chat_system_prompt": "chat system prompt",
"token_selection_biases": { "1": 2, "3": 4 },
"max_tokens": 128
"max_tokens": 128,
"seed": 123456,
"logprobs": true,
"top_logprobs": 5
}
""";
var actualSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
@@ -255,5 +266,8 @@ private static void AssertExecutionSettings(OpenAIPromptExecutionSettings execut
Assert.Equal("chat system prompt", executionSettings.ChatSystemPrompt);
Assert.Equal(new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } }, executionSettings.TokenSelectionBiases);
Assert.Equal(128, executionSettings.MaxTokens);
Assert.Equal(123456, executionSettings.Seed);
Assert.Equal(true, executionSettings.Logprobs);
Assert.Equal(5, executionSettings.TopLogprobs);
}
}
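
Because both properties carry [JsonPropertyName] attributes, they also round-trip through the snake_case JSON path the last test exercises. A short sketch, assuming the connector's existing OpenAIPromptExecutionSettings.FromExecutionSettings factory:

string json = """
    { "logprobs": true, "top_logprobs": 5 }
    """;
var genericSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
var settings = OpenAIPromptExecutionSettings.FromExecutionSettings(genericSettings);
// Expected: settings.Logprobs == true and settings.TopLogprobs == 5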
@@ -126,7 +126,8 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync()
PresencePenalty = 1.2,
ResultsPerPrompt = 5,
TokenSelectionBiases = new Dictionary<int, int> { { 2, 3 } },
StopSequences = ["stop_sequence"]
StopSequences = ["stop_sequence"],
TopLogprobs = 5
};

this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
@@ -154,6 +155,7 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync()
Assert.Equal(5, content.GetProperty("best_of").GetInt32());
Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32());
Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString());
Assert.Equal(5, content.GetProperty("logprobs").GetInt32());
}

[Fact]
@@ -9,6 +9,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Http.Resilience;
@@ -504,6 +505,38 @@ public async Task SemanticKernelVersionHeaderIsSentAsync()
Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values));
}

[Theory(Skip = "This test is for manual verification.")]
[InlineData(null, null)]
[InlineData(false, null)]
[InlineData(true, 2)]
[InlineData(true, 5)]
public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs)
{
// Arrange
var settings = new OpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs };

this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
var builder = this._kernelBuilder;
this.ConfigureAzureOpenAIChatAsText(builder);
Kernel target = builder.Build();

// Act
var result = await target.InvokePromptAsync("Hi, can you help me today?", new(settings));

var logProbabilityInfo = result.Metadata?["LogProbabilityInfo"] as ChatChoiceLogProbabilityInfo;

// Assert
if (logprobs is true)
{
Assert.NotNull(logProbabilityInfo);
Assert.Equal(topLogprobs, logProbabilityInfo.TokenLogProbabilityResults[0].TopLogProbabilityEntries.Count);
}
else
{
Assert.Null(logProbabilityInfo);
}
}
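
When the flag is set, callers can walk the returned structure token by token. A sketch follows; the Token and LogProbability member names follow the Azure.AI.OpenAI beta SDK referenced above and should be treated as assumptions:

if (result.Metadata?["LogProbabilityInfo"] is ChatChoiceLogProbabilityInfo info)
{
    foreach (var tokenResult in info.TokenLogProbabilityResults)
    {
        // The generated token and its log probability.
        Console.WriteLine($"{tokenResult.Token}: {tokenResult.LogProbability}");

        // The TopLogprobs alternatives considered at this position.
        foreach (var entry in tokenResult.TopLogProbabilityEntries)
        {
            Console.WriteLine($"  {entry.Token}: {entry.LogProbability}");
        }
    }
}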

#region internals

private readonly XunitLogger<Kernel> _logger = new(output);