.Net: Code clean-up #6266

Merged
40 changes: 13 additions & 27 deletions dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs
@@ -123,7 +123,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
}

// Find the function in the kernel and populate the arguments.
if (!kernel!.Plugins.TryGetFunctionAndArguments(kernel, toolCall.Function, out KernelFunction? function, out KernelArguments? functionArgs))
if (!kernel!.Plugins.TryGetFunctionAndArguments(toolCall.Function, out KernelFunction? function, out KernelArguments? functionArgs))
{
this.AddResponseMessage(chatRequest, chatHistory, toolCall, result: null, "Error: Requested function could not be found.");
continue;
@@ -237,15 +237,14 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
toolCalls?.Clear();

// Stream the responses
var response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken, kernel);
var response = this.StreamChatMessageContentsAsync(chatHistory, mistralExecutionSettings, chatRequest, modelId, cancellationToken);
string? streamedRole = null;
await foreach (var update in response.ConfigureAwait(false))
{
// If we're intending to invoke function calls, we need to consume that function call information.
if (autoInvoke)
{
var completionChunk = update.InnerContent as MistralChatCompletionChunk;
if (completionChunk is null || completionChunk.Choices is null || completionChunk.Choices?.Count == 0)
if (update.InnerContent is not MistralChatCompletionChunk completionChunk || completionChunk.Choices is null || completionChunk.Choices?.Count == 0)
{
continue;
}
@@ -261,7 +260,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
// to understand the tool call responses. Also add the result message to the caller's chat
// history: if they don't want it, they can remove it, but this makes the data available,
// including metadata like usage.
chatRequest.AddMessage(new MistralChatMessage(streamedRole!, completionChunk.GetContent(0)) { ToolCalls = chatChoice.ToolCalls });
chatRequest.AddMessage(new MistralChatMessage(streamedRole, completionChunk.GetContent(0)) { ToolCalls = chatChoice.ToolCalls });
chatHistory.Add(this.ToChatMessageContent(modelId, streamedRole!, completionChunk, chatChoice));
}
}
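For reference, callers consume the streaming path as a plain async stream; a minimal caller-side sketch, not part of this diff (model, apiKey, and the prompt are placeholders; service construction mirrors the integration tests further down):

// Caller-side sketch only.
var service = new MistralAIChatCompletionService(model!, apiKey!);
var chatHistory = new ChatHistory("Respond in French.");
chatHistory.AddUserMessage("What is the best French cheese?");

var builder = new StringBuilder();
await foreach (var update in service.GetStreamingChatMessageContentsAsync(chatHistory))
{
    builder.Append(update.Content); // Content can be null on tool-call-only chunks
}
Console.WriteLine(builder.ToString());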
@@ -315,7 +314,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
}

// Find the function in the kernel and populate the arguments.
if (!kernel!.Plugins.TryGetFunctionAndArguments(kernel, toolCall.Function, out KernelFunction? function, out KernelArguments? functionArgs))
if (!kernel!.Plugins.TryGetFunctionAndArguments(toolCall.Function, out KernelFunction? function, out KernelArguments? functionArgs))
{
this.AddResponseMessage(chatRequest, chatHistory, toolCall, result: null, "Error: Requested function could not be found.");
continue;
@@ -413,7 +412,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
}
}

private async IAsyncEnumerable<StreamingChatMessageContent> StreamChatMessageContentsAsync(ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, ChatCompletionRequest chatRequest, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken, Kernel? kernel = null)
private async IAsyncEnumerable<StreamingChatMessageContent> StreamChatMessageContentsAsync(ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, ChatCompletionRequest chatRequest, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken)
{
this.ValidateChatHistory(chatHistory);

@@ -481,7 +480,7 @@ internal async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<

var response = await this.SendRequestAsync<TextEmbeddingResponse>(httpRequestMessage, cancellationToken).ConfigureAwait(false);

return response.Data.Select(item => new ReadOnlyMemory<float>(item.Embedding.ToArray())).ToList();
return response.Data.Select(item => new ReadOnlyMemory<float>([.. item.Embedding])).ToList();
}
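On the caller side the embedding path is unchanged; a minimal sketch that mirrors the MistralAIGenerateEmbeddingsAsync integration test at the bottom of this diff (model and apiKey are placeholders):

// Caller-side sketch only; not part of this diff.
var service = new MistralAITextEmbeddingGenerationService(model!, apiKey!);
List<string> data = ["Hello", "world"];
IList<ReadOnlyMemory<float>> embeddings = await service.GenerateEmbeddingsAsync(data);
Console.WriteLine(embeddings[0].Length); // length of the first embedding vector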

#region private
@@ -503,7 +502,7 @@ internal async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<
/// to this limit, but if we do, auto-invoke will be disabled for the current flow in order to prevent runaway execution.
/// With the current setup, the way this could possibly happen is if a prompt function is configured with built-in
/// execution settings that opt-in to auto-invocation of everything in the kernel, in which case the invocation of that
/// prompt function could advertize itself as a candidate for auto-invocation. We don't want to outright block that,
/// prompt function could advertise itself as a candidate for auto-invocation. We don't want to outright block that,
/// if that's something a developer has asked to do (e.g. it might be invoked with different arguments than its parent
/// was invoked with), but we do want to limit it. This limit is arbitrary and can be tweaked in the future and/or made
/// configurable should need arise.
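A rough, hypothetical sketch of the kind of cap described in this comment (the constant name, value, and helper methods are assumptions for illustration, not the connector's actual code):

// Hypothetical sketch: bound auto-invocation so a function that advertises itself
// for auto-invocation cannot recurse indefinitely.
const int MaximumAutoInvokeAttempts = 5; // assumed cap
for (int attempt = 0; attempt < MaximumAutoInvokeAttempts; attempt++)
{
    var response = await SendChatRequestAsync(chatRequest); // placeholder helper
    var toolCalls = response.ToolCalls;                      // placeholder accessor
    if (toolCalls is null || toolCalls.Count == 0)
    {
        break; // nothing left to invoke, stop looping
    }
    foreach (var toolCall in toolCalls)
    {
        await InvokeAndAppendResultAsync(toolCall);          // invoke the kernel function, append result to history
    }
}
// If the loop exits because the cap was hit, auto-invoke is disabled for the rest of the flow.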
@@ -547,10 +546,7 @@ private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool s
request.SafePrompt = executionSettings.SafePrompt;
request.RandomSeed = executionSettings.RandomSeed;

if (executionSettings.ToolCallBehavior is not null)
{
executionSettings.ToolCallBehavior.ConfigureRequest(kernel, request);
}
executionSettings.ToolCallBehavior?.ConfigureRequest(kernel, request);
}

return request;
@@ -562,7 +558,7 @@ private List<MistralChatMessage> ToMistralChatMessages(ChatMessageContent conten
{
// Handling function calls supplied via ChatMessageContent.Items collection elements of the FunctionCallContent type.
var message = new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty);
Dictionary<string, MistralToolCall> toolCalls = new();
Dictionary<string, MistralToolCall> toolCalls = [];
foreach (var item in content.Items)
{
if (item is not FunctionCallContent callRequest)
@@ -590,7 +586,7 @@ private List<MistralChatMessage> ToMistralChatMessages(ChatMessageContent conten
}
if (toolCalls.Count > 0)
{
message.ToolCalls = toolCalls.Values.ToList();
message.ToolCalls = [.. toolCalls.Values];
}
return [message];
}
@@ -634,7 +630,7 @@ private void SetRequestHeaders(HttpRequestMessage request, string apiKey, bool s
request.Headers.Add("User-Agent", HttpHeaderConstant.Values.UserAgent);
request.Headers.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(this.GetType()));
request.Headers.Add("Accept", stream ? "text/event-stream" : "application/json");
request.Headers.Add("Authorization", $"Bearer {this._apiKey}");
request.Headers.Add("Authorization", $"Bearer {apiKey}");
request.Content!.Headers.ContentType = new MediaTypeHeaderValue("application/json");
}

@@ -679,12 +675,7 @@ private static T DeserializeResponse<T>(string body)
try
{
T? deserializedResponse = JsonSerializer.Deserialize<T>(body);
if (deserializedResponse is null)
{
throw new JsonException("Response is null");
}

return deserializedResponse;
return deserializedResponse ?? throw new JsonException("Response is null");
}
catch (JsonException exc)
{
@@ -700,11 +691,6 @@ private List<ChatMessageContent> ToChatMessageContent(string modelId, ChatComple
return response.Choices.Select(chatChoice => this.ToChatMessageContent(modelId, response, chatChoice)).ToList();
}

private List<ChatMessageContent> ToChatMessageContent(string modelId, string streamedRole, MistralChatCompletionChunk chunk)
{
return chunk.Choices.Select(chatChoice => this.ToChatMessageContent(modelId, streamedRole, chunk, chatChoice)).ToList();
}

private ChatMessageContent ToChatMessageContent(string modelId, ChatCompletionResponse response, MistralChatChoice chatChoice)
{
var message = new ChatMessageContent(new AuthorRole(chatChoice.Message!.Role!), chatChoice.Message!.Content, modelId, chatChoice, Encoding.UTF8, GetChatChoiceMetadata(response, chatChoice));
@@ -7,14 +7,23 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client;
/// <summary>
/// Usage for chat completion.
/// </summary>
internal class MistralUsage
public class MistralUsage
{
/// <summary>
/// The number of tokens in the provided prompts for the completions request.
/// </summary>
[JsonPropertyName("prompt_tokens")]
public int? PromptTokens { get; set; }

/// <summary>
/// The number of tokens generated across all completions emissions.
/// </summary>
[JsonPropertyName("completion_tokens")]
public int? CompletionTokens { get; set; }

/// <summary>
/// The total number of tokens processed for the completions request and response.
/// </summary>
[JsonPropertyName("total_tokens")]
public int? TotalTokens { get; set; }
}
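With MistralUsage now public, callers can inspect token counts surfaced in the response metadata; a small sketch based on the ValidateGetChatMessageContentsWithUsageAsync test added below (service, chatHistory, and executionSettings as in that test; the "Usage" metadata key is taken from the test):

// Caller-side sketch only; mirrors the new integration test.
var response = await service.GetChatMessageContentsAsync(chatHistory, executionSettings);
if (response[0].Metadata is { } metadata &&
    metadata.TryGetValue("Usage", out var value) &&
    value is MistralUsage usage)
{
    Console.WriteLine($"prompt={usage.PromptTokens} completion={usage.CompletionTokens} total={usage.TotalTokens}");
}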
@@ -16,7 +16,7 @@
<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>Semantic Kernel - Mistral AI connectors</Title>
<Description>Semantic Kernel connectors for Mistral. Contains clients for text generation and text embedding generation.</Description>
<Description>Semantic Kernel connectors for Mistral. Contains services for chat completion and text embedding generation.</Description>
</PropertyGroup>

<ItemGroup>
@@ -10,20 +10,18 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI;
/// <summary>
/// Extension methods for <see cref="IReadOnlyKernelPluginCollection"/>.
/// </summary>
public static class MistralAIPluginCollectionExtensions
internal static class MistralAIPluginCollectionExtensions
{
/// <summary>
/// Given an <see cref="MistralFunction"/> object, tries to retrieve the corresponding <see cref="KernelFunction"/> and populate <see cref="KernelArguments"/> with its parameters.
/// </summary>
/// <param name="plugins">The plugins.</param>
/// <param name="kernel">The <see cref="Kernel"/> instance</param>
/// <param name="functionToolCall">The <see cref="MistralFunction"/> object.</param>
/// <param name="function">When this method returns, the function that was retrieved if one with the specified name was found; otherwise, <see langword="null"/></param>
/// <param name="arguments">When this method returns, the arguments for the function; otherwise, <see langword="null"/></param>
/// <returns><see langword="true"/> if the function was found; otherwise, <see langword="false"/>.</returns>
internal static bool TryGetFunctionAndArguments(
this IReadOnlyKernelPluginCollection plugins,
Kernel kernel,
MistralFunction functionToolCall,
[NotNullWhen(true)] out KernelFunction? function,
out KernelArguments? arguments)
@@ -40,7 +38,7 @@ internal static bool TryGetFunctionAndArguments(

if (functionArguments is not null)
{
arguments = new KernelArguments();
arguments = [];

foreach (var key in functionArguments.Keys)
{
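The updated call pattern in MistralClient.cs, for reference (hedged sketch: toolCall stands in for a tool call returned by the model, and the extension itself is internal to the connector after this change):

// Sketch of the narrowed lookup: the kernel parameter is gone from the extension.
if (kernel.Plugins.TryGetFunctionAndArguments(toolCall.Function, out KernelFunction? function, out KernelArguments? functionArgs))
{
    var functionResult = await kernel.InvokeAsync(function, functionArgs);
    // append functionResult to the chat request / chat history as a tool message
}
else
{
    // function not found in the kernel: report "Error: Requested function could not be found."
}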
@@ -10,6 +10,7 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.MistralAI;
using Microsoft.SemanticKernel.Connectors.MistralAI.Client;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.MistralAI;
@@ -38,7 +39,7 @@ public MistralAIChatCompletionTests()
};
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsAsync()
{
// Arrange
@@ -60,7 +61,35 @@ public async Task ValidateGetChatMessageContentsAsync()
Assert.True(response[0].Content?.Length > 0);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithUsageAsync()
{
// Arrange
var model = this._configuration["MistralAI:ChatModel"];
var apiKey = this._configuration["MistralAI:ApiKey"];
var service = new MistralAIChatCompletionService(model!, apiKey!);

// Act
var chatHistory = new ChatHistory
{
new ChatMessageContent(AuthorRole.System, "Respond in French."),
new ChatMessageContent(AuthorRole.User, "What is the best French cheese?")
};
var response = await service.GetChatMessageContentsAsync(chatHistory, this._executionSettings);

// Assert
Assert.NotNull(response);
Assert.Single(response);
Assert.True(response[0].Content?.Length > 0);
Assert.NotNull(response[0].Metadata);
Assert.True(response[0].Metadata?.ContainsKey("Usage"));
var usage = response[0].Metadata?["Usage"] as MistralUsage;
Assert.True(usage?.CompletionTokens > 0);
Assert.True(usage?.PromptTokens > 0);
Assert.True(usage?.TotalTokens > 0);
}

[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateInvokeChatPromptAsync()
{
// Arrange
@@ -84,7 +113,7 @@ public async Task ValidateInvokeChatPromptAsync()
Assert.False(string.IsNullOrEmpty(response.ToString()));
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetStreamingChatMessageContentsAsync()
{
// Arrange
@@ -113,7 +142,7 @@ public async Task ValidateGetStreamingChatMessageContentsAsync()
Assert.False(string.IsNullOrEmpty(content.ToString()));
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsHasToolCallsResponseAsync()
{
// Arrange
@@ -137,7 +166,7 @@ public async Task ValidateGetChatMessageContentsHasToolCallsResponseAsync()
Assert.Equal("tool_calls", response[0].Metadata?["FinishReason"]);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsHasRequiredToolCallResponseAsync()
{
// Arrange
@@ -164,7 +193,7 @@ public async Task ValidateGetChatMessageContentsHasRequiredToolCallResponseAsync
Assert.Equal("DoSomething", ((FunctionCallContent)response[0].Items[1]).FunctionName);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithAutoInvokeAsync()
{
// Arrange
@@ -188,7 +217,7 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAsync()
Assert.Contains("sunny", response[0].Content, System.StringComparison.Ordinal);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithNoFunctionsAsync()
{
// Arrange
@@ -212,7 +241,7 @@ public async Task ValidateGetChatMessageContentsWithNoFunctionsAsync()
Assert.Contains("GetWeather", response[0].Content, System.StringComparison.Ordinal);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithAutoInvokeReturnsFunctionCallContentAsync()
{
// Arrange
@@ -239,7 +268,7 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeReturnsFunctionCal
Assert.Equal("GetWeather", ((FunctionCallContent)chatHistory[1].Items[1]).FunctionName);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionFilterAsync()
{
// Arrange
@@ -271,7 +300,7 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionFilterA
Assert.Contains("GetWeather", invokedFunctions);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionInvocationFilterAsync()
{
// Arrange
@@ -305,7 +334,7 @@ public async Task ValidateGetChatMessageContentsWithAutoInvokeAndFunctionInvocat
Assert.Contains("GetWeather", invokedFunctions);
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task ValidateGetChatMessageContentsWithAutoInvokeAndMultipleCallsAsync()
{
// Arrange
@@ -26,7 +26,7 @@ public MistralAITextEmbeddingTests()
.Build();
}

[Fact] // (Skip = "This test is for manual verification.")
[Fact(Skip = "This test is for manual verification.")]
public async Task MistralAIGenerateEmbeddingsAsync()
{
// Arrange
Expand All @@ -35,7 +35,7 @@ public async Task MistralAIGenerateEmbeddingsAsync()
var service = new MistralAITextEmbeddingGenerationService(model!, apiKey!);

// Act
List<string> data = new() { "Hello", "world" };
List<string> data = ["Hello", "world"];
var response = await service.GenerateEmbeddingsAsync(data);

// Assert