diff --git a/shell/AIShell.Abstraction/AIShell.Abstraction.csproj b/shell/AIShell.Abstraction/AIShell.Abstraction.csproj
index d6ed2e77..1eed2edf 100644
--- a/shell/AIShell.Abstraction/AIShell.Abstraction.csproj
+++ b/shell/AIShell.Abstraction/AIShell.Abstraction.csproj
@@ -7,5 +7,6 @@
+
diff --git a/shell/AIShell.Abstraction/IShell.cs b/shell/AIShell.Abstraction/IShell.cs
index 8b1ab7dd..c0c04867 100644
--- a/shell/AIShell.Abstraction/IShell.cs
+++ b/shell/AIShell.Abstraction/IShell.cs
@@ -1,3 +1,5 @@
+using Microsoft.Extensions.AI;
+
namespace AIShell.Abstraction;
///
@@ -27,6 +29,26 @@ public interface IShell
/// A list of code blocks or null if there is no code block.
List ExtractCodeBlocks(string text, out List sourceInfos);
+ ///
+ /// Get available instances for LLM to use.
+ ///
+ ///
+ Task> GetAIFunctions();
+
+ ///
+ /// Call an AI function.
+ ///
+ /// A instance representing the function call request.
+ /// Whether or not to capture the exception thrown from calling the tool.
+ /// Whether or not to include the exception message to the message of the call result.
+ /// The cancellation token to cancel the call.
+ ///
+ Task CallAIFunction(
+ FunctionCallContent functionCall,
+ bool captureException,
+ bool includeDetailedErrors,
+ CancellationToken cancellationToken);
+
// TODO:
// - methods to run code: python, command-line, powershell, node-js.
// - methods to communicate with shell client.
diff --git a/shell/AIShell.App/AIShell.App.csproj b/shell/AIShell.App/AIShell.App.csproj
index dd8f9592..86969f52 100644
--- a/shell/AIShell.App/AIShell.App.csproj
+++ b/shell/AIShell.App/AIShell.App.csproj
@@ -12,7 +12,6 @@
-
diff --git a/shell/AIShell.Kernel/AIShell.Kernel.csproj b/shell/AIShell.Kernel/AIShell.Kernel.csproj
index 442b0dbf..5a58665f 100644
--- a/shell/AIShell.Kernel/AIShell.Kernel.csproj
+++ b/shell/AIShell.Kernel/AIShell.Kernel.csproj
@@ -6,8 +6,8 @@
-
-
+
+
diff --git a/shell/AIShell.Kernel/Command/CommandRunner.cs b/shell/AIShell.Kernel/Command/CommandRunner.cs
index 515d430f..1043b532 100644
--- a/shell/AIShell.Kernel/Command/CommandRunner.cs
+++ b/shell/AIShell.Kernel/Command/CommandRunner.cs
@@ -35,6 +35,7 @@ internal CommandRunner(Shell shell)
new RefreshCommand(),
new RetryCommand(),
new HelpCommand(),
+ new McpCommand(),
//new RenderCommand(),
};
diff --git a/shell/AIShell.Kernel/Command/McpCommand.cs b/shell/AIShell.Kernel/Command/McpCommand.cs
new file mode 100644
index 00000000..58f452cc
--- /dev/null
+++ b/shell/AIShell.Kernel/Command/McpCommand.cs
@@ -0,0 +1,40 @@
+using System.CommandLine;
+using AIShell.Abstraction;
+
+namespace AIShell.Kernel.Commands;
+
+internal sealed class McpCommand : CommandBase
+{
+ public McpCommand()
+ : base("mcp", "Command for managing MCP servers and tools.")
+ {
+ this.SetHandler(ShowMCPData);
+
+ //var start = new Command("start", "Start an MCP server.");
+ //var stop = new Command("stop", "Stop an MCP server.");
+ //var server = new Argument(
+ // name: "server",
+ // getDefaultValue: () => null,
+ // description: "Name of an MCP server.").AddCompletions(AgentCompleter);
+
+ //start.AddArgument(server);
+ //start.SetHandler(StartMcpServer, server);
+
+ //stop.AddArgument(server);
+ //stop.SetHandler(StopMcpServer, server);
+ }
+
+ private void ShowMCPData()
+ {
+ var shell = (Shell)Shell;
+ var host = shell.Host;
+
+ if (shell.McpManager.McpServers.Count is 0)
+ {
+ host.WriteErrorLine("No MCP server is available.");
+ return;
+ }
+
+ host.RenderMcpServersAndTools(shell.McpManager);
+ }
+}
diff --git a/shell/AIShell.Kernel/Host.cs b/shell/AIShell.Kernel/Host.cs
index 27e94544..472fe223 100644
--- a/shell/AIShell.Kernel/Host.cs
+++ b/shell/AIShell.Kernel/Host.cs
@@ -2,9 +2,12 @@
using System.Text;
using AIShell.Abstraction;
+using AIShell.Kernel.Mcp;
using Markdig.Helpers;
using Microsoft.PowerShell;
using Spectre.Console;
+using Spectre.Console.Json;
+using Spectre.Console.Rendering;
namespace AIShell.Kernel;
@@ -175,7 +178,7 @@ public void RenderFullResponse(string response)
///
public void RenderTable(IList sources)
{
- RequireStdoutOrStderr(operation: "render table");
+ RequireStdout(operation: "render table");
ArgumentNullException.ThrowIfNull(sources);
if (sources.Count is 0)
@@ -198,7 +201,7 @@ public void RenderTable(IList sources)
///
public void RenderTable(IList sources, IList> elements)
{
- RequireStdoutOrStderr(operation: "render table");
+ RequireStdout(operation: "render table");
ArgumentNullException.ThrowIfNull(sources);
ArgumentNullException.ThrowIfNull(elements);
@@ -240,7 +243,7 @@ public void RenderTable(IList sources, IList> elements)
///
public void RenderList(T source)
{
- RequireStdoutOrStderr(operation: "render list");
+ RequireStdout(operation: "render list");
ArgumentNullException.ThrowIfNull(source);
if (source is IDictionary dict)
@@ -271,7 +274,7 @@ public void RenderList(T source)
///
public void RenderList(T source, IList> elements)
{
- RequireStdoutOrStderr(operation: "render list");
+ RequireStdout(operation: "render list");
ArgumentNullException.ThrowIfNull(source);
ArgumentNullException.ThrowIfNull(elements);
@@ -313,7 +316,7 @@ public void RenderList(T source, IList> elements)
public void RenderDivider(string text, DividerAlignment alignment)
{
ArgumentException.ThrowIfNullOrEmpty(text);
- RequireStdoutOrStderr(operation: "render divider");
+ RequireStdout(operation: "render divider");
if (!text.Contains("[/]"))
{
@@ -550,15 +553,134 @@ public string PromptForArgument(ArgumentInfo argInfo, bool printCaption)
internal void RenderReferenceText(string header, string content)
{
RequireStdoutOrStderr(operation: "Render reference");
+ IAnsiConsole ansiConsole = _outputRedirected ? _stderrConsole : AnsiConsole.Console;
var panel = new Panel($"\n[italic]{content.EscapeMarkup()}[/]\n")
.RoundedBorder()
.BorderColor(Color.DarkCyan)
.Header($"[orange3 on italic] {header.Trim()} [/]");
- AnsiConsole.WriteLine();
- AnsiConsole.Write(panel);
- AnsiConsole.WriteLine();
+ ansiConsole.WriteLine();
+ ansiConsole.Write(panel);
+ ansiConsole.WriteLine();
+ }
+
+ ///
+ /// Render the MCP tool call request.
+ ///
+ /// The MCP tool.
+ /// The arguments in JSON form to be sent for the tool call.
+ internal void RenderToolCallRequest(McpTool tool, string jsonArgs)
+ {
+ RequireStdoutOrStderr(operation: "render tool call request");
+ IAnsiConsole ansiConsole = _outputRedirected ? _stderrConsole : AnsiConsole.Console;
+
+ bool hasArgs = !string.IsNullOrEmpty(jsonArgs);
+ IRenderable content = new Markup($"""
+
+ [bold]Run [olive]{tool.OriginalName}[/] from [olive]{tool.ServerName}[/] (MCP server)[/]
+
+ {tool.Description}
+
+ Input:{(hasArgs ? string.Empty : " ")}
+ """);
+
+ if (hasArgs)
+ {
+ var json = new JsonText(jsonArgs)
+ .MemberColor(Color.Aqua)
+ .ColonColor(Color.White)
+ .CommaColor(Color.White)
+ .StringStyle(Color.Tan);
+
+ content = new Grid()
+ .AddColumn(new GridColumn())
+ .AddRow(content)
+ .AddRow(json);
+ }
+
+ var panel = new Panel(content)
+ .Expand()
+ .RoundedBorder()
+ .Header("[green] Tool Call Request [/]")
+ .BorderColor(Color.Grey);
+
+ ansiConsole.WriteLine();
+ ansiConsole.Write(panel);
+ FancyStreamRender.ConsoleUpdated();
+ }
+
+ ///
+ /// Render a table with information about available MCP servers and tools.
+ ///
+ /// The MCP manager instance.
+ internal void RenderMcpServersAndTools(McpManager mcpManager)
+ {
+ RequireStdout(operation: "render MCP servers and tools");
+
+ var toolTable = new Table()
+ .LeftAligned()
+ .SimpleBorder()
+ .BorderColor(Color.Green);
+
+ toolTable.AddColumn("[green bold]Server[/]");
+ toolTable.AddColumn("[green bold]Tool[/]");
+ toolTable.AddColumn("[green bold]Description[/]");
+
+ List<(string name, string status, string info)> readyServers = null, startingServers = null, failedServers = null;
+ foreach (var (name, server) in mcpManager.McpServers)
+ {
+ (int code, string status, string info) = server.IsInitFinished
+ ? server.Error is null
+ ? (1, "[green]\u2713 Ready[/]", string.Empty)
+ : (-1, "[red]\u2717 Failed[/]", $"[red]{server.Error.Message.EscapeMarkup()}[/]")
+ : (0, "[yellow]\u25CB Starting[/]", string.Empty);
+
+ var list = code switch
+ {
+ 1 => readyServers ??= [],
+ 0 => startingServers ??= [],
+ _ => failedServers ??= [],
+ };
+
+ list.Add((name, status, info));
+ }
+
+ if (startingServers is not null)
+ {
+ foreach (var (name, status, info) in startingServers)
+ {
+ toolTable.AddRow($"[olive underline]{name}[/]", status, info);
+ }
+ }
+
+ if (failedServers is not null)
+ {
+ foreach (var (name, status, info) in failedServers)
+ {
+ toolTable.AddRow($"[olive underline]{name}[/]", status, info);
+ }
+ }
+
+ if (readyServers is not null)
+ {
+ foreach (var (name, status, info) in readyServers)
+ {
+ if (toolTable.Rows is { Count: > 0 })
+ {
+ toolTable.AddEmptyRow();
+ }
+
+ var server = mcpManager.McpServers[name];
+ toolTable.AddRow($"[olive underline]{name}[/]", status, info);
+ foreach (var item in server.Tools)
+ {
+ toolTable.AddRow(string.Empty, item.Key.EscapeMarkup(), item.Value.Description.EscapeMarkup());
+ }
+ }
+ }
+
+ AnsiConsole.Write(toolTable);
}
private static Spinner GetSpinner(SpinnerKind? kind)
@@ -583,6 +705,19 @@ private void RequireStdin(string operation)
}
}
+ ///
+ /// Throw exception if standard output is redirected.
+ ///
+ /// The intended operation.
+ /// Throw the exception if stdout is redirected.
+ private void RequireStdout(string operation)
+ {
+ if (_outputRedirected)
+ {
+ throw new InvalidOperationException($"Cannot {operation} when the stdout is redirected.");
+ }
+ }
+
///
/// Throw exception if both standard output and error are redirected.
///
diff --git a/shell/AIShell.Kernel/MCP/McpConfig.cs b/shell/AIShell.Kernel/MCP/McpConfig.cs
new file mode 100644
index 00000000..a8453eea
--- /dev/null
+++ b/shell/AIShell.Kernel/MCP/McpConfig.cs
@@ -0,0 +1,195 @@
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol;
+
+namespace AIShell.Kernel.Mcp;
+
+///
+/// MCP configuration defined in mcp.json.
+///
+internal class McpConfig
+{
+ [JsonPropertyName("servers")]
+ public Dictionary Servers { get; set; } = [];
+
+ internal static McpConfig Load()
+ {
+ McpConfig mcpConfig = null;
+ FileInfo file = new(Utils.AppMcpFile);
+ if (file.Exists)
+ {
+ using var stream = file.OpenRead();
+ mcpConfig = JsonSerializer.Deserialize(stream, McpJsonContext.Default.McpConfig);
+ mcpConfig.Validate();
+ }
+
+ return mcpConfig is { Servers.Count: 0 } ? null : mcpConfig;
+ }
+
+ ///
+ /// Post-deserialization validation.
+ ///
+ ///
+ private void Validate()
+ {
+ List allErrors = null;
+
+ foreach (var (name, server) in Servers)
+ {
+ server.Name = name;
+ if (Enum.TryParse(server.Type, ignoreCase: true, out McpType mcpType))
+ {
+ server.Transport = mcpType;
+ }
+ else
+ {
+ (allErrors ??= []).Add($"Server '{name}': 'type' is required and the value should be one of the following: {string.Join(',', Enum.GetNames())}.");
+ continue;
+ }
+
+ List curErrs = null;
+
+ if (mcpType is McpType.stdio)
+ {
+ bool hasUrlGroup = !string.IsNullOrEmpty(server.Url) || server.Headers is { };
+ if (hasUrlGroup)
+ {
+ (curErrs ??= []).Add($"'url' and 'headers' fields are invalid for 'stdio' type servers.");
+ }
+
+ if (string.IsNullOrEmpty(server.Command))
+ {
+ (curErrs ??= []).Add($"'command' is required for 'stdio' type servers.");
+ }
+ }
+ else
+ {
+ bool hasCommandGroup = !string.IsNullOrEmpty(server.Command) || server.Args is { } || server.Env is { };
+ if (hasCommandGroup)
+ {
+ (curErrs ??= []).Add($"'command', 'args', and 'env' fields are invalid for '{mcpType}' type servers.");
+ }
+
+ if (string.IsNullOrEmpty(server.Url))
+ {
+ (curErrs ??= []).Add($"'url' is required for '{mcpType}' type servers.");
+ }
+ else if (!Uri.TryCreate(server.Url, UriKind.Absolute, out Uri uri))
+ {
+ (curErrs ??= []).Add($"the specified value for 'url' is not a valid URI.");
+ }
+ else if (uri.Scheme != Uri.UriSchemeHttp && uri.Scheme != Uri.UriSchemeHttps)
+ {
+ (curErrs ??= []).Add($"'url' is expected to be a 'http' or 'https' resource.");
+ }
+ else
+ {
+ server.Endpoint = uri;
+ }
+ }
+
+ if (curErrs is [string onlyErr])
+ {
+ (allErrors ??= []).Add($"Server '{name}': {onlyErr}");
+ }
+ else if (curErrs is { Count: > 1 })
+ {
+ string prefix = $"Server '{name}':";
+ int size = curErrs.Sum(a => a.Length) + curErrs.Count * 5 + prefix.Length;
+ StringBuilder sb = new(prefix, capacity: size);
+
+ foreach (string element in curErrs)
+ {
+ sb.Append($"\n - {element}");
+ }
+
+ (allErrors ??= []).Add(sb.ToString());
+ }
+ }
+
+ if (allErrors is { })
+ {
+ string errorMsg = string.Join('\n', allErrors);
+ throw new InvalidOperationException(errorMsg);
+ }
+ }
+}
+
+///
+/// Configuration of a server.
+///
+internal class McpServerConfig
+{
+ [JsonPropertyName("type")]
+ public string Type { get; set; }
+
+ [JsonPropertyName("command")]
+ public string Command { get; set; }
+
+ [JsonPropertyName("args")]
+ public List Args { get; set; }
+
+ [JsonPropertyName("env")]
+ public Dictionary Env { get; set; }
+
+ [JsonPropertyName("url")]
+ public string Url { get; set; }
+
+ [JsonPropertyName("headers")]
+ public Dictionary Headers { get; set; }
+
+ internal string Name { get; set; }
+ internal Uri Endpoint { get; set; }
+ internal McpType Transport { get; set; }
+
+ internal IClientTransport ToClientTransport()
+ {
+ return Transport switch
+ {
+ McpType.stdio => new StdioClientTransport(new()
+ {
+ Name = Name,
+ Command = Command,
+ Arguments = Args,
+ EnvironmentVariables = Env,
+ }),
+
+ _ => new SseClientTransport(new()
+ {
+ Name = Name,
+ Endpoint = Endpoint,
+ AdditionalHeaders = Headers,
+ TransportMode = Transport is McpType.sse ? HttpTransportMode.Sse : HttpTransportMode.StreamableHttp,
+ ConnectionTimeout = TimeSpan.FromSeconds(15),
+ })
+ };
+ }
+}
+
+///
+/// MCP transport types.
+///
+internal enum McpType
+{
+ stdio,
+ sse,
+ http,
+}
+
+///
+/// Source generation helper for deserializing mcp.json.
+///
+[JsonSourceGenerationOptions(
+ WriteIndented = true,
+ AllowTrailingCommas = true,
+ PropertyNameCaseInsensitive = true,
+ ReadCommentHandling = JsonCommentHandling.Skip,
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
+ UseStringEnumConverter = true)]
+[JsonSerializable(typeof(McpConfig))]
+[JsonSerializable(typeof(McpServerConfig))]
+[JsonSerializable(typeof(CallToolResponse))]
+internal partial class McpJsonContext : JsonSerializerContext { }
diff --git a/shell/AIShell.Kernel/MCP/McpManager.cs b/shell/AIShell.Kernel/MCP/McpManager.cs
new file mode 100644
index 00000000..b70575e8
--- /dev/null
+++ b/shell/AIShell.Kernel/MCP/McpManager.cs
@@ -0,0 +1,168 @@
+using Microsoft.Extensions.AI;
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol;
+
+namespace AIShell.Kernel.Mcp;
+
+internal class McpManager
+{
+ private readonly Task _initTask;
+ private readonly McpServerInitContext _context;
+ private readonly Dictionary _mcpServers;
+ private readonly TaskCompletionSource _parseMcpJsonTaskSource;
+
+ private McpConfig _mcpConfig;
+
+ internal Task ParseMcpJsonTask => _parseMcpJsonTaskSource.Task;
+
+ internal Dictionary McpServers
+ {
+ get
+ {
+ _initTask.Wait();
+ return _mcpServers;
+ }
+ }
+
+ internal McpManager(Shell shell)
+ {
+ _context = new(shell);
+ _parseMcpJsonTaskSource = new();
+ _mcpServers = new(StringComparer.OrdinalIgnoreCase);
+
+ _initTask = Task.Run(Initialize);
+ }
+
+ private void Initialize()
+ {
+ try
+ {
+ _mcpConfig = McpConfig.Load();
+ _parseMcpJsonTaskSource.SetResult(_mcpConfig);
+ }
+ catch (Exception e)
+ {
+ _parseMcpJsonTaskSource.SetException(e);
+ }
+
+ if (_mcpConfig is null)
+ {
+ return;
+ }
+
+ foreach (var (name, config) in _mcpConfig.Servers)
+ {
+ _mcpServers.Add(name, new McpServer(config, _context));
+ }
+ }
+
+ ///
+ /// Lists tools that are available at the time of the call.
+ /// Servers that are still initializing or failed will be skipped.
+ ///
+ internal async Task> ListAvailableTools()
+ {
+ await _initTask;
+
+ List tools = null;
+ foreach (var (name, server) in _mcpServers)
+ {
+ if (server.IsOperational)
+ {
+ (tools ??= []).AddRange(server.Tools.Values);
+ }
+ }
+
+ return tools;
+ }
+
+ ///
+ /// Make a tool call using the given function call data.
+ ///
+ /// The function call request.
+ /// Whether or not to capture the exception thrown from calling the tool.
+ /// Whether or not to include the exception message to the message of the call result.
+ /// The cancellation token to cancel the call.
+ ///
+ internal async Task CallToolAsync(
+ FunctionCallContent functionCall,
+ bool captureException = false,
+ bool includeDetailedErrors = false,
+ CancellationToken cancellationToken = default)
+ {
+ string serverName = null, toolName = null;
+
+ string functionName = functionCall.Name;
+ int dotIndex = functionName.IndexOf(McpTool.ServerToolSeparator);
+ if (dotIndex > 0)
+ {
+ serverName = functionName[..dotIndex];
+ toolName = functionName[(dotIndex + 1)..];
+ }
+
+ await _initTask;
+
+ McpTool tool = null;
+ if (!string.IsNullOrEmpty(serverName)
+ && !string.IsNullOrEmpty(toolName)
+ && _mcpServers.TryGetValue(serverName, out McpServer server))
+ {
+ await server.WaitForInitAsync(cancellationToken);
+ server.Tools.TryGetValue(toolName, out tool);
+ }
+
+ if (tool is null)
+ {
+ return new FunctionResultContent(
+ functionCall.CallId,
+ $"Error: Requested function \"{functionName}\" not found.");
+ }
+
+ FunctionResultContent resultContent = new(functionCall.CallId, result: null);
+
+ try
+ {
+ CallToolResponse response = await tool.CallAsync(
+ new AIFunctionArguments(arguments: functionCall.Arguments),
+ cancellationToken: cancellationToken);
+
+ resultContent.Result = (object)response ?? "Success: Function completed.";
+ }
+ catch (Exception e) when (!cancellationToken.IsCancellationRequested)
+ {
+ if (!captureException)
+ {
+ throw;
+ }
+
+ string message = "Error: Function failed.";
+ resultContent.Exception = e;
+ resultContent.Result = includeDetailedErrors ? $"{message} Exception: {e.Message}" : message;
+ }
+
+ return resultContent;
+ }
+}
+
+internal class McpServerInitContext
+{
+ ///
+ /// The throttle limit defines the maximum number of servers that can be initiated concurrently.
+ ///
+ private const int ThrottleLimit = 5;
+
+ internal McpServerInitContext(Shell shell)
+ {
+ Shell = shell;
+ ThrottleSemaphore = new SemaphoreSlim(ThrottleLimit, ThrottleLimit);
+ ClientOptions = new()
+ {
+ ClientInfo = new() { Name = "AIShell", Version = shell.Version },
+ InitializationTimeout = TimeSpan.FromSeconds(30),
+ };
+ }
+
+ internal Shell Shell { get; }
+ internal SemaphoreSlim ThrottleSemaphore { get; }
+ internal McpClientOptions ClientOptions { get; }
+}
diff --git a/shell/AIShell.Kernel/MCP/McpServer.cs b/shell/AIShell.Kernel/MCP/McpServer.cs
new file mode 100644
index 00000000..df43d3f8
--- /dev/null
+++ b/shell/AIShell.Kernel/MCP/McpServer.cs
@@ -0,0 +1,135 @@
+using ModelContextProtocol.Client;
+
+namespace AIShell.Kernel.Mcp;
+
+internal class McpServer : IDisposable
+{
+ private readonly McpServerConfig _config;
+ private readonly McpServerInitContext _context;
+ private readonly Dictionary _tools;
+ private readonly Task _initTask;
+
+ private string _serverInfo;
+ private IMcpClient _client;
+ private Exception _error;
+
+ ///
+ /// Name of the server declared in mcp.json.
+ ///
+ internal string Name => _config.Name;
+
+ ///
+ /// Gets whether the initialization is done.
+ ///
+ internal bool IsInitFinished => _initTask.IsCompleted;
+
+ ///
+ /// Gets whether the server is operational.
+ ///
+ internal bool IsOperational => _initTask.IsCompleted && _error is null;
+
+ ///
+ /// Full name and version of the server.
+ ///
+ internal string ServerInfo
+ {
+ get
+ {
+ WaitForInit();
+ return _serverInfo;
+ }
+ }
+
+ ///
+ /// The client connected to the server.
+ ///
+ internal IMcpClient Client
+ {
+ get
+ {
+ WaitForInit();
+ return _client;
+ }
+ }
+
+ ///
+ /// Exposed tools from the server.
+ ///
+ internal Dictionary Tools
+ {
+ get
+ {
+ WaitForInit();
+ return _tools;
+ }
+ }
+
+ internal Exception Error
+ {
+ get
+ {
+ WaitForInit();
+ return _error;
+ }
+ }
+
+ internal McpServer(McpServerConfig config, McpServerInitContext context)
+ {
+ _config = config;
+ _context = context;
+ _tools = new(StringComparer.OrdinalIgnoreCase);
+ _initTask = Initialize();
+ }
+
+ private async Task Initialize()
+ {
+ try
+ {
+ await _context.ThrottleSemaphore.WaitAsync();
+
+ IClientTransport transport = _config.ToClientTransport();
+ _client = await McpClientFactory.CreateAsync(transport, _context.ClientOptions);
+
+ var serverInfo = _client.ServerInfo;
+ // An MCP server may have the name included in the version info.
+ _serverInfo = serverInfo.Version.Contains(serverInfo.Name, StringComparison.OrdinalIgnoreCase)
+ ? serverInfo.Version
+ : $"{serverInfo.Name} {serverInfo.Version}";
+
+ await foreach (McpClientTool tool in _client.EnumerateToolsAsync())
+ {
+ _tools.TryAdd(tool.Name, new McpTool(Name, tool, _context.Shell.Host));
+ }
+ }
+ catch (Exception e)
+ {
+ _error = e;
+ _tools.Clear();
+ if (_client is { })
+ {
+ await _client.DisposeAsync();
+ _client = null;
+ }
+ }
+ finally
+ {
+ _context.ThrottleSemaphore.Release();
+ }
+ }
+
+ internal void WaitForInit(CancellationToken cancellationToken = default)
+ {
+ _initTask.Wait(cancellationToken);
+ }
+
+ internal async Task WaitForInitAsync(CancellationToken cancellationToken = default)
+ {
+ await _initTask.WaitAsync(cancellationToken);
+ }
+
+ public void Dispose()
+ {
+ _tools.Clear();
+ _client?.DisposeAsync().AsTask().Wait();
+ }
+}
diff --git a/shell/AIShell.Kernel/MCP/McpTool.cs b/shell/AIShell.Kernel/MCP/McpTool.cs
new file mode 100644
index 00000000..10771b01
--- /dev/null
+++ b/shell/AIShell.Kernel/MCP/McpTool.cs
@@ -0,0 +1,137 @@
+using System.Text.Json;
+using AIShell.Abstraction;
+using Microsoft.Extensions.AI;
+using ModelContextProtocol;
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol;
+
+namespace AIShell.Kernel.Mcp;
+
+///
+/// A wrapper class of to make sure the call to the tool always go through the AIShell.
+///
+internal class McpTool : AIFunction
+{
+ private readonly string _fullName;
+ private readonly string _serverName;
+ private readonly Host _host;
+ private readonly McpClientTool _clientTool;
+ private readonly string[] _userChoices;
+
+ internal const string ServerToolSeparator = "___";
+
+ internal McpTool(string serverName, McpClientTool clientTool, Host host)
+ {
+ _host = host;
+ _clientTool = clientTool;
+ _fullName = $"{serverName}{ServerToolSeparator}{clientTool.Name}";
+ _serverName = serverName;
+ _userChoices = ["Continue", "Cancel"];
+ }
+
+ ///
+ /// The server name for this tool.
+ ///
+ internal string ServerName => _serverName;
+
+ ///
+ /// The original tool name without the server name prefix.
+ ///
+ internal string OriginalName => _clientTool.Name;
+
+ ///
+ /// The fully qualified name of the tool in the form of '.'
+ ///
+ public override string Name => _fullName;
+
+ ///
+ public override string Description => _clientTool.Description;
+
+ ///
+ public override JsonElement JsonSchema => _clientTool.JsonSchema;
+
+ ///
+ public override JsonSerializerOptions JsonSerializerOptions => _clientTool.JsonSerializerOptions;
+
+ ///
+ public override IReadOnlyDictionary AdditionalProperties => _clientTool.AdditionalProperties;
+
+ ///
+ /// Overrides the base method with the call to . The only difference in behavior is we will serialize
+ /// the resulting such that the returned is a
+ /// containing the serialized . This method is intended to be used polymorphically via the base
+ /// class, typically as part of an operation.
+ ///
+ protected override async ValueTask
internal LLMAgent ActiveAgent => _activeAgent;
+ ///
+ /// Gets the version from the assembly attribute.
+ ///
+ internal string Version => _version;
+
+ ///
+ /// Gets the MCP manager.
+ ///
+ internal McpManager McpManager => _mcpManager;
+
#region IShell implementation
IHost IShell.Host => Host;
bool IShell.ChannelEstablished => Channel is not null;
CancellationToken IShell.CancellationToken => _cancellationSource.Token;
List IShell.ExtractCodeBlocks(string text, out List sourceInfos) => Utils.ExtractCodeBlocks(text, out sourceInfos);
+ async Task> IShell.GetAIFunctions() => await _mcpManager.ListAvailableTools();
+ async Task IShell.CallAIFunction(FunctionCallContent functionCall, bool captureException, bool includeDetailedErrors, CancellationToken cancellationToken)
+ => await _mcpManager.CallToolAsync(functionCall, captureException, includeDetailedErrors, cancellationToken);
#endregion IShell implementation
@@ -104,6 +120,7 @@ internal Shell(bool interactive, ShellArgs args)
_textToIgnore = new HashSet(StringComparer.OrdinalIgnoreCase);
_cancellationSource = new CancellationTokenSource();
_version = typeof(Shell).Assembly.GetCustomAttribute().InformationalVersion;
+ _mcpManager = new(this);
Exit = false;
Regenerate = false;
@@ -156,9 +173,21 @@ internal void ShowLandingPage()
_activeAgent.Display(Host, isWrapped ? _wrapper.Description : null);
}
+ try
+ {
+ McpConfig config = _mcpManager.ParseMcpJsonTask.GetAwaiter().GetResult();
+ if (config is { Servers.Count: > 0 })
+ {
+ Host.MarkupNoteLine($"{config.Servers.Count} MCP server(s) configured. Run {Formatter.Command("/mcp")} for details.");
+ }
+ }
+ catch (Exception e)
+ {
+ Host.WriteErrorLine($"Failed to load the 'mcp.json' file:\n{e.Message}\nRun '/mcp config' to open the 'mcp.json' file.\n");
+ }
+
// Write out help.
- Host.MarkupLine($"Run {Formatter.Command("/help")} for more instructions.")
- .WriteLine();
+ Host.MarkupLine($"Run {Formatter.Command("/help")} for more instructions.\n");
}
///
diff --git a/shell/AIShell.Kernel/Utility/LoadContext.cs b/shell/AIShell.Kernel/Utility/LoadContext.cs
index 3cb70c08..ae00182d 100644
--- a/shell/AIShell.Kernel/Utility/LoadContext.cs
+++ b/shell/AIShell.Kernel/Utility/LoadContext.cs
@@ -25,7 +25,14 @@ internal AgentAssemblyLoadContext(string name, string dependencyDir)
_dependencyDir = dependencyDir;
_runtimeLibDir = [];
_runtimeNativeDir = [];
- _cache = [];
+ _cache = new()
+ {
+ // Contracts exposed from 'AIShell.Abstraction' depend on these assemblies,
+ // so agents have to depend on the same assemblies from the default ALC.
+ // Otherwise, the contracts will break due to mis-match type identities.
+ ["System.CommandLine"] = null,
+ ["Microsoft.Extensions.AI.Abstractions"] = null,
+ };
if (OperatingSystem.IsWindows())
{
diff --git a/shell/AIShell.Kernel/Utility/Utils.cs b/shell/AIShell.Kernel/Utility/Utils.cs
index bc9428f0..d8cde85c 100644
--- a/shell/AIShell.Kernel/Utility/Utils.cs
+++ b/shell/AIShell.Kernel/Utility/Utils.cs
@@ -46,6 +46,7 @@ internal static class Utils
internal static string ConfigHome;
internal static string AppCacheDir;
internal static string AppConfigFile;
+ internal static string AppMcpFile;
internal static string AgentHome;
internal static string AgentConfigHome;
@@ -59,6 +60,7 @@ internal static void Setup(string appName)
ConfigHome = Path.Combine(locationPath, $".{AppName.Replace(' ', '-')}");
AppCacheDir = Path.Combine(ConfigHome, ".cache");
AppConfigFile = Path.Combine(ConfigHome, "config.json");
+ AppMcpFile = Path.Combine(ConfigHome, "mcp.json");
AgentHome = Path.Join(ConfigHome, "agents");
AgentConfigHome = Path.Join(ConfigHome, "agent-config");
diff --git a/shell/Markdown.VT/Markdown.VT.csproj b/shell/Markdown.VT/Markdown.VT.csproj
index 12b0e835..fcf0d7f6 100644
--- a/shell/Markdown.VT/Markdown.VT.csproj
+++ b/shell/Markdown.VT/Markdown.VT.csproj
@@ -6,9 +6,9 @@
-
+
-
+
diff --git a/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj b/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj
index d03afbaa..3e7464d1 100644
--- a/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj
+++ b/shell/agents/AIShell.Ollama.Agent/AIShell.Ollama.Agent.csproj
@@ -17,7 +17,7 @@
-
+
diff --git a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
index b0c53299..21a674aa 100644
--- a/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
+++ b/shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
@@ -21,11 +21,10 @@
-
-
-
-
-
+
+
+
+
diff --git a/shell/agents/AIShell.OpenAI.Agent/Agent.cs b/shell/agents/AIShell.OpenAI.Agent/Agent.cs
index c1a93c76..cbbcab1a 100644
--- a/shell/agents/AIShell.OpenAI.Agent/Agent.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/Agent.cs
@@ -1,8 +1,7 @@
-using System.ClientModel;
using System.Text;
using System.Text.Json;
using AIShell.Abstraction;
-using OpenAI.Chat;
+using Microsoft.Extensions.AI;
namespace AIShell.OpenAI.Agent;
@@ -116,42 +115,86 @@ public async Task ChatAsync(string input, IShell shell)
return checkPass;
}
- IAsyncEnumerator response = await host
+ IAsyncEnumerator response = await host
.RunWithSpinnerAsync(
- () => _chatService.GetStreamingChatResponseAsync(input, token)
+ () => _chatService.GetStreamingChatResponseAsync(input, shell, token)
).ConfigureAwait(false);
if (response is not null)
{
- StreamingChatCompletionUpdate update = null;
+ int? toolCalls = null;
+ bool isReasoning = false;
+ List updates = [];
using var streamingRender = host.NewStreamRender(token);
try
{
do
{
- update = response.Current;
- if (update.ContentUpdate.Count > 0)
+ if (toolCalls is 0)
{
- streamingRender.Refresh(update.ContentUpdate[0].Text);
+ toolCalls = null;
+ }
+
+ ChatResponseUpdate update = response.Current;
+ updates.Add(update);
+
+ foreach (AIContent content in update.Contents)
+ {
+ if (content is TextReasoningContent reason)
+ {
+ if (isReasoning)
+ {
+ streamingRender.Refresh(reason.Text);
+ }
+ else
+ {
+ isReasoning = true;
+ streamingRender.Refresh($"\n{reason.Text}");
+ }
+
+ continue;
+ }
+
+ string message = content switch
+ {
+ TextContent text => text.Text ?? string.Empty,
+ ErrorContent error => error.Message ?? string.Empty,
+ _ => null
+ };
+
+ if (message is null)
+ {
+ toolCalls = content switch
+ {
+ FunctionCallContent => (toolCalls + 1) ?? 1,
+ FunctionResultContent => toolCalls - 1,
+ _ => toolCalls
+ };
+ }
+
+ if (isReasoning)
+ {
+ isReasoning = false;
+ message = $"\n\n\n{message}";
+ }
+
+ if (!string.IsNullOrEmpty(message))
+ {
+ streamingRender.Refresh(message);
+ }
}
}
- while (await response.MoveNextAsync().ConfigureAwait(continueOnCapturedContext: false));
+ while (toolCalls is 0
+ ? await host.RunWithSpinnerAsync(() => response.MoveNextAsync().AsTask()).ConfigureAwait(false)
+ : await response.MoveNextAsync().ConfigureAwait(false));
}
catch (OperationCanceledException)
{
- update = null;
+ // Ignore cancellation exception.
}
- if (update is null)
- {
- _chatService.CalibrateChatHistory(usage: null, response: null);
- }
- else
- {
- string responseContent = streamingRender.AccumulatedContent;
- _chatService.CalibrateChatHistory(update.Usage, new AssistantChatMessage(responseContent));
- }
+ _chatService.ChatHistory.AddMessages(updates);
}
return checkPass;
diff --git a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
index 2ce1dd9f..790d705a 100644
--- a/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
@@ -1,5 +1,3 @@
-using Microsoft.ML.Tokenizers;
-
namespace AIShell.OpenAI.Agent;
internal class ModelInfo
@@ -11,7 +9,6 @@ internal class ModelInfo
private const string Gpt34Encoding = "cl100k_base";
private static readonly Dictionary s_modelMap;
- private static readonly Dictionary> s_encodingMap;
// A rough estimate to cover all third-party models.
// - most popular models today support 32K+ context length;
@@ -37,21 +34,12 @@ static ModelInfo()
// Azure naming of the 'gpt-3.5-turbo' models
["gpt-35-turbo"] = new(tokenLimit: 16_385),
};
-
- // The first call to 'GptEncoding.GetEncoding' is very slow, taking about 2 seconds on my machine.
- // We don't immediately need the encodings at the startup, so by getting the values in tasks,
- // we don't block the startup and the values will be ready when we really need them.
- s_encodingMap = new(StringComparer.OrdinalIgnoreCase)
- {
- [Gpt34Encoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt34Encoding)),
- [Gpt4oEncoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt4oEncoding))
- };
}
private ModelInfo(int tokenLimit, string encoding = null, bool reasoning = false)
{
TokenLimit = tokenLimit;
- _encodingName = encoding ?? Gpt34Encoding;
+ EncodingName = encoding ?? Gpt34Encoding;
// For gpt4o, gpt4 and gpt3.5-turbo, the following 2 properties are the same.
// See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -60,22 +48,11 @@ private ModelInfo(int tokenLimit, string encoding = null, bool reasoning = false
Reasoning = reasoning;
}
- private readonly string _encodingName;
- private Tokenizer _gptEncoding;
-
+ internal string EncodingName { get; }
internal int TokenLimit { get; }
internal int TokensPerMessage { get; }
internal int TokensPerName { get; }
internal bool Reasoning { get; }
- internal Tokenizer Encoding
- {
- get {
- _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task value)
- ? value.Result
- : TiktokenTokenizer.CreateForEncoding(_encodingName);
- return _gptEncoding;
- }
- }
///
/// Try resolving the specified model name.
diff --git a/shell/agents/AIShell.OpenAI.Agent/Service.cs b/shell/agents/AIShell.OpenAI.Agent/Service.cs
index 026a5fde..b366b915 100644
--- a/shell/agents/AIShell.OpenAI.Agent/Service.cs
+++ b/shell/agents/AIShell.OpenAI.Agent/Service.cs
@@ -1,104 +1,47 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using Azure.AI.OpenAI;
-using Azure.Core;
using Azure.Identity;
-using Microsoft.ML.Tokenizers;
+using Microsoft.Extensions.AI;
using OpenAI;
-using OpenAI.Chat;
+
+using OpenAIChatClient = OpenAI.Chat.ChatClient;
+using ChatMessage = Microsoft.Extensions.AI.ChatMessage;
+using AIShell.Abstraction;
namespace AIShell.OpenAI.Agent;
internal class ChatService
{
// TODO: Maybe expose this to our model registration?
- // We can still use 1000 as the default value.
- private const int MaxResponseToken = 2000;
+ private const int MaxResponseToken = 1000;
private readonly string _historyRoot;
private readonly List _chatHistory;
- private readonly List _chatHistoryTokens;
- private readonly ChatCompletionOptions _chatOptions;
+ private readonly ChatOptions _chatOptions;
private GPT _gptToUse;
private Settings _settings;
- private ChatClient _client;
- private int _totalInputToken;
+ private IChatClient _client;
internal ChatService(string historyRoot, Settings settings)
{
_chatHistory = [];
- _chatHistoryTokens = [];
_historyRoot = historyRoot;
-
- _totalInputToken = 0;
_settings = settings;
- _chatOptions = new ChatCompletionOptions()
+ _chatOptions = new ChatOptions()
{
- MaxOutputTokenCount = MaxResponseToken,
+ MaxOutputTokens = MaxResponseToken,
};
}
internal List ChatHistory => _chatHistory;
- internal void AddResponseToHistory(string response)
- {
- if (string.IsNullOrEmpty(response))
- {
- return;
- }
-
- _chatHistory.Add(ChatMessage.CreateAssistantMessage(response));
- }
-
internal void RefreshSettings(Settings settings)
{
_settings = settings;
}
- ///
- /// It's almost impossible to relative-accurately calculate the token counts of all
- /// messages, especially when tool calls are involved (tool call definitions and the
- /// tool call payloads in AI response).
- /// So, I decide to leverage the useage report from AI to track the token count of
- /// the chat history. It's also an estimate, but I think more accurate than doing the
- /// counting by ourselves.
- ///
- internal void CalibrateChatHistory(ChatTokenUsage usage, AssistantChatMessage response)
- {
- if (usage is null)
- {
- // Response was cancelled and we will remove the last query from history.
- int index = _chatHistory.Count - 1;
- _chatHistory.RemoveAt(index);
- _chatHistoryTokens.RemoveAt(index);
-
- return;
- }
-
- // Every reply is primed with <|start|>assistant<|message|>, so we subtract 3 from the 'InputTokenCount'.
- int promptTokenCount = usage.InputTokenCount - 3;
- // 'ReasoningTokenCount' should be 0 for non-o1 models.
- int reasoningTokenCount = usage.OutputTokenDetails is null ? 0 : usage.OutputTokenDetails.ReasoningTokenCount;
- int responseTokenCount = usage.OutputTokenCount - reasoningTokenCount;
-
- if (_totalInputToken is 0)
- {
- // It was the first user message, so instead of adjusting the user message token count,
- // we set the token count for system message and tool calls.
- _chatHistoryTokens[0] = promptTokenCount - _chatHistoryTokens[^1];
- }
- else
- {
- // Adjust the token count of the user message, as our calculation is an estimate.
- _chatHistoryTokens[^1] = promptTokenCount - _totalInputToken;
- }
-
- _chatHistory.Add(response);
- _chatHistoryTokens.Add(responseTokenCount);
- _totalInputToken = promptTokenCount + responseTokenCount;
- }
-
private void RefreshOpenAIClient()
{
if (ReferenceEquals(_gptToUse, _settings.Active))
@@ -110,7 +53,6 @@ private void RefreshOpenAIClient()
GPT old = _gptToUse;
_gptToUse = _settings.Active;
_chatHistory.Clear();
- _chatHistoryTokens.Clear();
if (old is not null
&& old.Type == _gptToUse.Type
@@ -124,6 +66,7 @@ private void RefreshOpenAIClient()
return;
}
+ OpenAIChatClient client;
EndpointType type = _gptToUse.Type;
// Reasoning models do not support the temperature setting.
_chatOptions.Temperature = _gptToUse.ModelInfo.Reasoning ? null : 0;
@@ -154,7 +97,7 @@ private void RefreshOpenAIClient()
new ApiKeyCredential(azOpenAIApiKey),
clientOptions);
- _client = aiClient.GetChatClient(_gptToUse.Deployment);
+ client = aiClient.GetChatClient(_gptToUse.Deployment);
}
else
{
@@ -165,7 +108,7 @@ private void RefreshOpenAIClient()
credential,
clientOptions);
- _client = aiClient.GetChatClient(_gptToUse.Deployment);
+ client = aiClient.GetChatClient(_gptToUse.Deployment);
}
}
else
@@ -179,91 +122,50 @@ private void RefreshOpenAIClient()
string userKey = Utils.ConvertFromSecureString(_gptToUse.Key);
var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions);
- _client = aiClient.GetChatClient(_gptToUse.ModelName);
- }
- }
-
- ///
- /// Reference: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- ///
- private int CountTokenForUserMessage(UserChatMessage message)
- {
- ModelInfo modelDetail = _gptToUse.ModelInfo;
- Tokenizer encoding = modelDetail.Encoding;
-
- // Tokens per message plus 1 token for the role.
- int tokenNumber = modelDetail.TokensPerMessage + 1;
- foreach (ChatMessageContentPart part in message.Content)
- {
- tokenNumber += encoding.CountTokens(part.Text);
+ client = aiClient.GetChatClient(_gptToUse.ModelName);
}
- return tokenNumber;
+ _client = client.AsIChatClient()
+ .AsBuilder()
+ .UseFunctionInvocation(configure: c => c.IncludeDetailedErrors = true)
+ .Build();
}
private void PrepareForChat(string input)
{
+ const string Guidelines = """
+ ## Tool Use Guidelines
+ You may have access to external tools.
+ Before making any tool call, you must first explain the reason for using the tool. Only issue the tool call after providing this explanation.
+
+ ## Other Guidelines
+ """;
+
// Refresh the client in case the active model was changed.
RefreshOpenAIClient();
if (_chatHistory.Count is 0)
{
- _chatHistory.Add(ChatMessage.CreateSystemMessage(_gptToUse.SystemPrompt));
- _chatHistoryTokens.Add(0);
+ string system = $"{Guidelines}\n{_gptToUse.SystemPrompt}";
+ _chatHistory.Add(new(ChatRole.System, system));
}
- var userMessage = new UserChatMessage(input);
- int msgTokenCnt = CountTokenForUserMessage(userMessage);
- _chatHistory.Add(userMessage);
- _chatHistoryTokens.Add(msgTokenCnt);
-
- int inputLimit = _gptToUse.ModelInfo.TokenLimit;
- // Every reply is primed with <|start|>assistant<|message|>, so adding 3 tokens.
- int newTotal = _totalInputToken + msgTokenCnt + 3;
-
- // Shrink the chat history if we have less than 50 free tokens left (50-token buffer).
- while (inputLimit - newTotal < 50)
- {
- // We remove a round of conversation for every trimming operation.
- int userMsgCnt = 0;
- List indices = [];
-
- for (int i = 0; i < _chatHistory.Count; i++)
- {
- if (_chatHistory[i] is UserChatMessage)
- {
- if (userMsgCnt is 1)
- {
- break;
- }
-
- userMsgCnt++;
- }
-
- if (userMsgCnt is 1)
- {
- indices.Add(i);
- }
- }
-
- foreach (int i in indices)
- {
- newTotal -= _chatHistoryTokens[i];
- }
-
- _chatHistory.RemoveRange(indices[0], indices.Count);
- _chatHistoryTokens.RemoveRange(indices[0], indices.Count);
- _totalInputToken = newTotal - msgTokenCnt;
- }
+ _chatHistory.Add(new(ChatRole.User, input));
}
- public async Task> GetStreamingChatResponseAsync(string input, CancellationToken cancellationToken)
+ public async Task> GetStreamingChatResponseAsync(string input, IShell shell, CancellationToken cancellationToken)
{
try
{
PrepareForChat(input);
- IAsyncEnumerator enumerator = _client
- .CompleteChatStreamingAsync(_chatHistory, _chatOptions, cancellationToken)
+ var tools = await shell.GetAIFunctions();
+ if (tools is { Count: > 0 })
+ {
+ _chatOptions.Tools = [.. tools];
+ }
+
+ IAsyncEnumerator enumerator = _client
+ .GetStreamingResponseAsync(_chatHistory, _chatOptions, cancellationToken)
.GetAsyncEnumerator(cancellationToken);
return await enumerator