diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props
index 26bb8e881632..cde814b1fb31 100644
--- a/dotnet/Directory.Packages.props
+++ b/dotnet/Directory.Packages.props
@@ -68,7 +68,7 @@
-
+
@@ -215,9 +215,9 @@
runtime; build; native; contentfiles; analyzers; buildtransitive
-
-
-
+
+
+
diff --git a/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs
index b26ecbeb6b0d..c18b76ffed4b 100644
--- a/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs
+++ b/dotnet/src/Connectors/Connectors.Onnx/OnnxRuntimeGenAIChatCompletionService.cs
@@ -2,8 +2,6 @@
using System;
using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
-using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
@@ -20,12 +18,13 @@ namespace Microsoft.SemanticKernel.Connectors.Onnx;
///
public sealed class OnnxRuntimeGenAIChatCompletionService : IChatCompletionService, IDisposable
{
- private readonly string _modelId;
private readonly string _modelPath;
- private readonly JsonSerializerOptions? _jsonSerializerOptions;
- private Model? _model;
- private Tokenizer? _tokenizer;
- private Dictionary AttributesInternal { get; } = new();
+ private OnnxRuntimeGenAIChatClient? _chatClient;
+ private IChatCompletionService? _chatClientWrapper;
+ private readonly Dictionary _attributesInternal = [];
+
+ ///
+ public IReadOnlyDictionary Attributes => this._attributesInternal;
///
/// Initializes a new instance of the OnnxRuntimeGenAIChatCompletionService class.
@@ -43,174 +42,38 @@ public OnnxRuntimeGenAIChatCompletionService(
Verify.NotNullOrWhiteSpace(modelId);
Verify.NotNullOrWhiteSpace(modelPath);
- this._modelId = modelId;
+ this._attributesInternal.Add(AIServiceExtensions.ModelIdKey, modelId);
this._modelPath = modelPath;
- this._jsonSerializerOptions = jsonSerializerOptions;
- this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, this._modelId);
- }
-
- ///
- public IReadOnlyDictionary Attributes => this.AttributesInternal;
-
- ///
- public async Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
- {
- var result = new StringBuilder();
-
- await foreach (var content in this.RunInferenceAsync(chatHistory, executionSettings, cancellationToken).ConfigureAwait(false))
- {
- result.Append(content);
- }
-
- return new List
- {
- new(
- role: AuthorRole.Assistant,
- modelId: this._modelId,
- content: result.ToString())
- };
- }
-
- ///
- public async IAsyncEnumerable GetStreamingChatMessageContentsAsync(
- ChatHistory chatHistory,
- PromptExecutionSettings? executionSettings = null,
- Kernel? kernel = null,
- [EnumeratorCancellation] CancellationToken cancellationToken = default)
- {
- await foreach (var content in this.RunInferenceAsync(chatHistory, executionSettings, cancellationToken).ConfigureAwait(false))
- {
- yield return new StreamingChatMessageContent(AuthorRole.Assistant, content, modelId: this._modelId);
- }
}
- private async IAsyncEnumerable RunInferenceAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings, [EnumeratorCancellation] CancellationToken cancellationToken)
+ private IChatCompletionService GetChatCompletionService()
{
- OnnxRuntimeGenAIPromptExecutionSettings onnxPromptExecutionSettings = this.GetOnnxPromptExecutionSettingsSettings(executionSettings);
-
- var prompt = this.GetPrompt(chatHistory, onnxPromptExecutionSettings);
- using var tokens = this.GetTokenizer().Encode(prompt);
-
- using var generatorParams = new GeneratorParams(this.GetModel());
- this.UpdateGeneratorParamsFromPromptExecutionSettings(generatorParams, onnxPromptExecutionSettings);
-
- using var generator = new Generator(this.GetModel(), generatorParams);
- generator.AppendTokenSequences(tokens);
-
- bool removeNextTokenStartingWithSpace = true;
- while (!generator.IsDone())
+ this._chatClient ??= new OnnxRuntimeGenAIChatClient(this._modelPath, new OnnxRuntimeGenAIChatClientOptions()
{
- cancellationToken.ThrowIfCancellationRequested();
-
- yield return await Task.Run(() =>
+ PromptFormatter = (messages, options) =>
{
- generator.GenerateNextToken();
-
- var outputTokens = generator.GetSequence(0);
- var newToken = outputTokens[outputTokens.Length - 1];
-
- using var tokenizerStream = this.GetTokenizer().CreateStream();
- string output = tokenizerStream.Decode(newToken);
-
- if (removeNextTokenStartingWithSpace && output[0] == ' ')
+ StringBuilder promptBuilder = new();
+ foreach (var message in messages)
{
- removeNextTokenStartingWithSpace = false;
- output = output.TrimStart();
+ promptBuilder.Append($"<|{message.Role}|>\n{message.Text}");
}
+ promptBuilder.Append("<|end|>\n<|assistant|>");
- return output;
- }, cancellationToken).ConfigureAwait(false);
- }
- }
-
- private Model GetModel() => this._model ??= new Model(this._modelPath);
-
- private Tokenizer GetTokenizer() => this._tokenizer ??= new Tokenizer(this.GetModel());
+ return promptBuilder.ToString();
+ }
+ });
- private string GetPrompt(ChatHistory chatHistory, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings)
- {
- var promptBuilder = new StringBuilder();
- foreach (var message in chatHistory)
- {
- promptBuilder.Append($"<|{message.Role}|>\n{message.Content}");
- }
- promptBuilder.Append("<|end|>\n<|assistant|>");
-
- return promptBuilder.ToString();
+ return this._chatClientWrapper ??= this._chatClient.AsChatCompletionService();
}
- private void UpdateGeneratorParamsFromPromptExecutionSettings(GeneratorParams generatorParams, OnnxRuntimeGenAIPromptExecutionSettings onnxRuntimeGenAIPromptExecutionSettings)
- {
- if (onnxRuntimeGenAIPromptExecutionSettings.TopP.HasValue)
- {
- generatorParams.SetSearchOption("top_p", onnxRuntimeGenAIPromptExecutionSettings.TopP.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.TopK.HasValue)
- {
- generatorParams.SetSearchOption("top_k", onnxRuntimeGenAIPromptExecutionSettings.TopK.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.Temperature.HasValue)
- {
- generatorParams.SetSearchOption("temperature", onnxRuntimeGenAIPromptExecutionSettings.Temperature.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.RepetitionPenalty.HasValue)
- {
- generatorParams.SetSearchOption("repetition_penalty", onnxRuntimeGenAIPromptExecutionSettings.RepetitionPenalty.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.PastPresentShareBuffer.HasValue)
- {
- generatorParams.SetSearchOption("past_present_share_buffer", onnxRuntimeGenAIPromptExecutionSettings.PastPresentShareBuffer.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.NumReturnSequences.HasValue)
- {
- generatorParams.SetSearchOption("num_return_sequences", onnxRuntimeGenAIPromptExecutionSettings.NumReturnSequences.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.NoRepeatNgramSize.HasValue)
- {
- generatorParams.SetSearchOption("no_repeat_ngram_size", onnxRuntimeGenAIPromptExecutionSettings.NoRepeatNgramSize.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.MinTokens.HasValue)
- {
- generatorParams.SetSearchOption("min_length", onnxRuntimeGenAIPromptExecutionSettings.MinTokens.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.MaxTokens.HasValue)
- {
- generatorParams.SetSearchOption("max_length", onnxRuntimeGenAIPromptExecutionSettings.MaxTokens.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.LengthPenalty.HasValue)
- {
- generatorParams.SetSearchOption("length_penalty", onnxRuntimeGenAIPromptExecutionSettings.LengthPenalty.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.EarlyStopping.HasValue)
- {
- generatorParams.SetSearchOption("early_stopping", onnxRuntimeGenAIPromptExecutionSettings.EarlyStopping.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.DoSample.HasValue)
- {
- generatorParams.SetSearchOption("do_sample", onnxRuntimeGenAIPromptExecutionSettings.DoSample.Value);
- }
- if (onnxRuntimeGenAIPromptExecutionSettings.DiversityPenalty.HasValue)
- {
- generatorParams.SetSearchOption("diversity_penalty", onnxRuntimeGenAIPromptExecutionSettings.DiversityPenalty.Value);
- }
- }
-
- [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via the class constructor.")]
- [UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "JSOs are required only in cases where the supplied settings are not Onnx-specific. For these cases, JSOs can be provided via class constructor.")]
- private OnnxRuntimeGenAIPromptExecutionSettings GetOnnxPromptExecutionSettingsSettings(PromptExecutionSettings? executionSettings)
- {
- if (this._jsonSerializerOptions is not null)
- {
- return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings, this._jsonSerializerOptions);
- }
+ ///
+ public void Dispose() => this._chatClient?.Dispose();
- return OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
- }
+ ///
+ public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
+ this.GetChatCompletionService().GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
///
- public void Dispose()
- {
- this._tokenizer?.Dispose();
- this._model?.Dispose();
- }
+ public IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) =>
+ this.GetChatCompletionService().GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken);
}