.Net: Implement OnnxRuntimeGenAIChatCompletionService on OnnxRuntimeGenAIChatClient #12197

Merged (5 commits, Jun 10, 2025)
8 changes: 4 additions & 4 deletions dotnet/Directory.Packages.props
@@ -68,7 +68,7 @@
<PackageVersion Include="Microsoft.Identity.Client" Version="4.67.2" />
<PackageVersion Include="Microsoft.Identity.Client.Extensions.Msal" Version="4.67.2" />
<PackageVersion Include="Microsoft.IdentityModel.JsonWebTokens" Version="7.5.1" />
<PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.21.0" />
<PackageVersion Include="Microsoft.ML.OnnxRuntime" Version="1.22.0" />
<PackageVersion Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />
<PackageVersion Include="Microsoft.SemanticKernel.Abstractions" Version="1.47.0" />
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.OpenAI" Version="1.47.0" />
@@ -215,9 +215,9 @@
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
     </PackageReference>
     <!-- OnnxRuntimeGenAI -->
-    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.7.0-rc2" />
-    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.7.0-rc2" />
-    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.7.0-rc2" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.8.1" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.8.1" />
+    <PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.8.1" />
     <!-- SpectreConsole-->
     <PackageVersion Include="Spectre.Console" Version="0.49.1" />
     <PackageVersion Include="Spectre.Console.Cli" Version="0.49.1" />
OnnxRuntimeGenAIChatCompletionService.cs
@@ -2,8 +2,6 @@

 using System;
 using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
-using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
 using System.Threading;
@@ -20,12 +18,13 @@ namespace Microsoft.SemanticKernel.Connectors.Onnx;
 /// </summary>
 public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionService, IDisposable
 {
-    private readonly string _modelId;
     private readonly string _modelPath;
     private readonly JsonSerializerOptions? _jsonSerializerOptions;
-    private Model? _model;
-    private Tokenizer? _tokenizer;
-    private Dictionary<string, object?> AttributesInternal { get; } = new();
+    private OnnxRuntimeGenAIChatClient? _chatClient;
+    private IChatCompletionService? _chatClientWrapper;
+    private readonly Dictionary<string, object?> _attributesInternal = [];
+
+    /// <inheritdoc/>
+    public IReadOnlyDictionary<string, object?> Attributes => this._attributesInternal;
 
     /// <summary>
     /// Initializes a new instance of the OnnxRuntimeGenAIChatCompletionService class.
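The heart of this PR is the hunk below: instead of driving Model, Tokenizer, GeneratorParams, and Generator by hand, the service now lazily constructs an OnnxRuntimeGenAIChatClient (the IChatClient implementation that ships with Microsoft.ML.OnnxRuntimeGenAI) and adapts it to Semantic Kernel's IChatCompletionService via AsChatCompletionService(). A minimal sketch of that wrap-and-adapt pattern in isolation; the model path is a placeholder and the using directives are assumptions, not lines from this PR:

    using Microsoft.Extensions.AI;
    using Microsoft.ML.OnnxRuntimeGenAI;
    using Microsoft.SemanticKernel.ChatCompletion;

    // Build the GenAI-provided IChatClient over a local ONNX model directory
    // (placeholder path).
    using var chatClient = new OnnxRuntimeGenAIChatClient(
        @"C:\models\phi-3-mini", new OnnxRuntimeGenAIChatClientOptions());

    // Adapt it to Semantic Kernel's IChatCompletionService abstraction; this is
    // the same adapter call the refactored service makes internally.
    IChatCompletionService service = chatClient.AsChatCompletionService();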
@@ -43,174 +42,38 @@ public OnnxRuntimeGenAIChatCompletionService(
         Verify.NotNullOrWhiteSpace(modelId);
         Verify.NotNullOrWhiteSpace(modelPath);
 
-        this._modelId = modelId;
+        this._attributesInternal.Add(AIServiceExtensions.ModelIdKey, modelId);
         this._modelPath = modelPath;
         this._jsonSerializerOptions = jsonSerializerOptions;
-        this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, this._modelId);
     }
 
-    /// <inheritdoc />
-    public IReadOnlyDictionary<string, object?> Attributes => this.AttributesInternal;
-
-    /// <inheritdoc />
-    public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
-    {
-        var result = new StringBuilder();
-
-        await foreach (var content in this.RunInferenceAsync(chatHistory, executionSettings, cancellationToken).ConfigureAwait(false))
-        {
-            result.Append(content);
-        }
-
-        return new List<ChatMessageContent>
-        {
-            new(
-                role: AuthorRole.Assistant,
-                modelId: this._modelId,
-                content: result.ToString())
-        };
-    }
-
-    /// <inheritdoc />
-    public async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(
-        ChatHistory chatHistory,
-        PromptExecutionSettings? executionSettings = null,
-        Kernel? kernel = null,
-        [EnumeratorCancellation] CancellationToken cancellationToken = default)
-    {
-        await foreach (var content in this.RunInferenceAsync(chatHistory, executionSettings, cancellationToken).ConfigureAwait(false))
-        {
-            yield return new StreamingChatMessageContent(AuthorRole.Assistant, content, modelId: this._modelId);
-        }
-    }
-
-    private async IAsyncEnumerable<string> RunInferenceAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, [EnumeratorCancellation] CancellationToken cancellationToken)
+    private IChatCompletionService GetChatCompletionService()
     {
-        OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = this.GetOnnxPromptExecutionSettingsSettings(executionSettings);
-
-        var prompt = this.GetPrompt(chatHistory, onnxPromptExecutionSettings);
-        using var tokens = this.GetTokenizer().Encode(prompt);
-
-        using var generatorParams = new GeneratorParams(this.GetModel());
-        this.UpdateGeneratorParamsFromPromptExecutionSettings(generatorParams, onnxPromptExecutionSettings);
-
-        using var generator = new Generator(this.GetModel(), generatorParams);
-        generator.AppendTokenSequences(tokens);
-
-        bool removeNextTokenStartingWithSpace = true;
-        while (!generator.IsDone())
-        {
-            cancellationToken.ThrowIfCancellationRequested();
-
-            yield return await Task.Run(() =>
-            {
-                generator.GenerateNextToken();
-
-                var outputTokens = generator.GetSequence(0);
-                var newToken = outputTokens[outputTokens.Length - 1];
-
-                using var tokenizerStream = this.GetTokenizer().CreateStream();
-                string output = tokenizerStream.Decode(newToken);
-
-                if (removeNextTokenStartingWithSpace && output[0] == ' ')
-                {
-                    removeNextTokenStartingWithSpace = false;
-                    output = output.TrimStart();
-                }
-
-                return output;
-            }, cancellationToken).ConfigureAwait(false);
-        }
-    }
-
-    private Model GetModel() => this._model ??= new Model(this._modelPath);
-
-    private Tokenizer GetTokenizer() => this._tokenizer ??= new Tokenizer(this.GetModel());
+        this._chatClient ??= new OnnxRuntimeGenAIChatClient(this._modelPath, new OnnxRuntimeGenAIChatClientOptions()
+        {
+            PromptFormatter = (messages, options) =>
+            {
+                StringBuilder promptBuilder = new();
+                foreach (var message in messages)
+                {
+                    promptBuilder.Append($"<|{message.Role}|>\n{message.Text}");
+                }
+                promptBuilder.Append("<|end|>\n<|assistant|>");
+
+                return promptBuilder.ToString();
+            }
+        });

-    private string GetPrompt(ChatHistory chatHistory, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings)
-    {
-        var promptBuilder = new StringBuilder();
-        foreach (var message in chatHistory)
-        {
-            promptBuilder.Append($"<|{message.Role}|>\n{message.Content}");
-        }
-        promptBuilder.Append("<|end|>\n<|assistant|>");
-
-        return promptBuilder.ToString();
-    }
+        return this._chatClientWrapper ??= this._chatClient.AsChatCompletionService();
+    }
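Note that the PromptFormatter above reproduces what the removed GetPrompt helper emitted, so the prompt the model sees keeps the same shape: each message renders as <|role|> followed by a newline and the message text, and the whole prompt ends with <|end|>\n<|assistant|>. One small behavioral change: the old streaming loop's leading-space trimming (removeNextTokenStartingWithSpace) is gone, with token decoding now left to the wrapped client. A standalone illustration of the formatter's output, not code from this PR:

    using System.Text;

    // Illustrative only: the string the formatter produces for a two-message history.
    var sb = new StringBuilder();
    foreach (var (role, text) in new[] { ("system", "You are helpful."), ("user", "Hi") })
    {
        sb.Append($"<|{role}|>\n{text}");
    }
    sb.Append("<|end|>\n<|assistant|>");
    // sb.ToString():
    // "<|system|>\nYou are helpful.<|user|>\nHi<|end|>\n<|assistant|>"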

-    private void UpdateGeneratorParamsFromPromptExecutionSettings(GeneratorParams generatorParams, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings)
-    {
-        if (onnxRuntimeGenAIPromptExecutionSettings.TopP.HasValue)
-        {
-            generatorParams.SetSearchOption("top_p", onnxRuntimeGenAIPromptExecutionSettings.TopP.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.TopK.HasValue)
-        {
-            generatorParams.SetSearchOption("top_k", onnxRuntimeGenAIPromptExecutionSettings.TopK.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.Temperature.HasValue)
-        {
-            generatorParams.SetSearchOption("temperature", onnxRuntimeGenAIPromptExecutionSettings.Temperature.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.RepetitionPenalty.HasValue)
-        {
-            generatorParams.SetSearchOption("repetition_penalty", onnxRuntimeGenAIPromptExecutionSettings.RepetitionPenalty.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.PastPresentShareBuffer.HasValue)
-        {
-            generatorParams.SetSearchOption("past_present_share_buffer", onnxRuntimeGenAIPromptExecutionSettings.PastPresentShareBuffer.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.NumReturnSequences.HasValue)
-        {
-            generatorParams.SetSearchOption("num_return_sequences", onnxRuntimeGenAIPromptExecutionSettings.NumReturnSequences.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.NoRepeatNgramSize.HasValue)
-        {
-            generatorParams.SetSearchOption("no_repeat_ngram_size", onnxRuntimeGenAIPromptExecutionSettings.NoRepeatNgramSize.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.MinTokens.HasValue)
-        {
-            generatorParams.SetSearchOption("min_length", onnxRuntimeGenAIPromptExecutionSettings.MinTokens.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.MaxTokens.HasValue)
-        {
-            generatorParams.SetSearchOption("max_length", onnxRuntimeGenAIPromptExecutionSettings.MaxTokens.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.LengthPenalty.HasValue)
-        {
-            generatorParams.SetSearchOption("length_penalty", onnxRuntimeGenAIPromptExecutionSettings.LengthPenalty.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.EarlyStopping.HasValue)
-        {
-            generatorParams.SetSearchOption("early_stopping", onnxRuntimeGenAIPromptExecutionSettings.EarlyStopping.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.DoSample.HasValue)
-        {
-            generatorParams.SetSearchOption("do_sample", onnxRuntimeGenAIPromptExecutionSettings.DoSample.Value);
-        }
-        if (onnxRuntimeGenAIPromptExecutionSettings.DiversityPenalty.HasValue)
-        {
-            generatorParams.SetSearchOption("diversity_penalty", onnxRuntimeGenAIPromptExecutionSettings.DiversityPenalty.Value);
-        }
-    }
-
-    [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via the class constructor.")]
-    [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via class constructor.")]
-    private OnnxRuntimeGenAIPromptExecutionSettings GetOnnxPromptExecutionSettingsSettings(PromptExecutionSettings? executionSettings)
-    {
-        if (this._jsonSerializerOptions is not null)
-        {
-            return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings, this._jsonSerializerOptions);
-        }
-
-        return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
-    }
+    /// <inheritdoc/>
+    public void Dispose() => this._chatClient?.Dispose();
 
     /// <inheritdoc/>
-    public void Dispose()
-    {
-        this._tokenizer?.Dispose();
-        this._model?.Dispose();
-    }
+    public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
+        this.GetChatCompletionService().GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
+
+    /// <inheritdoc/>
+    public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
+        this.GetChatCompletionService().GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
 }
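With the manual generation loop and the search-option plumbing gone, the public surface is a thin delegation layer: both interface methods route through the lazily created wrapper, and disposing the service now disposes only the wrapped chat client, since the Model and Tokenizer fields it used to own no longer exist. A usage sketch under assumptions: the model id and path below are placeholders, and the constructor shape is the two-parameter form visible in this diff:

    using Microsoft.SemanticKernel.ChatCompletion;
    using Microsoft.SemanticKernel.Connectors.Onnx;

    // Placeholders: point modelPath at a real local ONNX model directory.
    using var service = new OnnxRuntimeGenAIChatCompletionService(
        modelId: "phi-3",
        modelPath: @"C:\models\phi-3-mini");

    var history = new ChatHistory();
    history.AddSystemMessage("You are a helpful assistant.");
    history.AddUserMessage("What is the capital of France?");

    // Delegates to the wrapped OnnxRuntimeGenAIChatClient via AsChatCompletionService().
    var messages = await service.GetChatMessageContentsAsync(history);
    Console.WriteLine(messages[0].Content);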